WIP multi embedders

fixed template bugs
This commit is contained in:
Louis Dureuil 2023-12-12 21:19:48 +01:00
parent abbe131084
commit 922a640188
No known key found for this signature in database
20 changed files with 438 additions and 158 deletions

View File

@ -1361,7 +1361,6 @@ impl IndexScheduler {
let embedder = Arc::new( let embedder = Arc::new(
Embedder::new(embedder_options.clone()) Embedder::new(embedder_options.clone())
.map_err(meilisearch_types::milli::vector::Error::from) .map_err(meilisearch_types::milli::vector::Error::from)
.map_err(meilisearch_types::milli::UserError::from)
.map_err(meilisearch_types::milli::Error::from)?, .map_err(meilisearch_types::milli::Error::from)?,
); );
{ {

View File

@ -222,6 +222,8 @@ InvalidVectorsType , InvalidRequest , BAD_REQUEST ;
InvalidDocumentId , InvalidRequest , BAD_REQUEST ; InvalidDocumentId , InvalidRequest , BAD_REQUEST ;
InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ; InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ;
InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ; InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ;
InvalidEmbedder , InvalidRequest , BAD_REQUEST ;
InvalidHybridQuery , InvalidRequest , BAD_REQUEST ;
InvalidIndexLimit , InvalidRequest , BAD_REQUEST ; InvalidIndexLimit , InvalidRequest , BAD_REQUEST ;
InvalidIndexOffset , InvalidRequest , BAD_REQUEST ; InvalidIndexOffset , InvalidRequest , BAD_REQUEST ;
InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ; InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ;
@ -233,6 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
InvalidSemanticRatio , InvalidRequest , BAD_REQUEST ;
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
@ -340,6 +343,7 @@ impl ErrorCode for milli::Error {
} }
UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, UserError::MissingDocumentField(_) => Code::InvalidDocumentFields,
UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound,
UserError::MultiplePrimaryKeyCandidatesFound { .. } => { UserError::MultiplePrimaryKeyCandidatesFound { .. } => {
@ -363,6 +367,7 @@ impl ErrorCode for milli::Error {
UserError::InvalidMinTypoWordLenSetting(_, _) => { UserError::InvalidMinTypoWordLenSetting(_, _) => {
Code::InvalidSettingsTypoTolerance Code::InvalidSettingsTypoTolerance
} }
UserError::InvalidEmbedder(_) => Code::InvalidEmbedder,
UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError, UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError,
} }
} }

View File

@ -36,7 +36,7 @@ use crate::routes::{create_all_stats, Stats};
use crate::search::{ use crate::search::{
FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult, FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO,
}; };
use crate::Opt; use crate::Opt;
@ -586,6 +586,11 @@ pub struct SearchAggregator {
// vector // vector
// The maximum number of floats in a vector request // The maximum number of floats in a vector request
max_vector_size: usize, max_vector_size: usize,
// Whether the semantic ratio passed to a hybrid search equals the default ratio.
semantic_ratio: bool,
// Whether a non-default embedder was specified
embedder: bool,
hybrid: bool,
// every time a search is done, we increment the counter linked to the used settings // every time a search is done, we increment the counter linked to the used settings
matching_strategy: HashMap<String, usize>, matching_strategy: HashMap<String, usize>,
@ -639,6 +644,7 @@ impl SearchAggregator {
crop_marker, crop_marker,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
} = query; } = query;
let mut ret = Self::default(); let mut ret = Self::default();
@ -712,6 +718,12 @@ impl SearchAggregator {
ret.show_ranking_score = *show_ranking_score; ret.show_ranking_score = *show_ranking_score;
ret.show_ranking_score_details = *show_ranking_score_details; ret.show_ranking_score_details = *show_ranking_score_details;
if let Some(hybrid) = hybrid {
ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO();
ret.embedder = hybrid.embedder.is_some();
ret.hybrid = true;
}
ret ret
} }
@ -765,6 +777,9 @@ impl SearchAggregator {
facets_total_number_of_facets, facets_total_number_of_facets,
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
semantic_ratio,
embedder,
hybrid,
} = other; } = other;
if self.timestamp.is_none() { if self.timestamp.is_none() {
@ -810,6 +825,9 @@ impl SearchAggregator {
// vector // vector
self.max_vector_size = self.max_vector_size.max(max_vector_size); self.max_vector_size = self.max_vector_size.max(max_vector_size);
self.semantic_ratio |= semantic_ratio;
self.hybrid |= hybrid;
self.embedder |= embedder;
// pagination // pagination
self.max_limit = self.max_limit.max(max_limit); self.max_limit = self.max_limit.max(max_limit);
@ -878,6 +896,9 @@ impl SearchAggregator {
facets_total_number_of_facets, facets_total_number_of_facets,
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
semantic_ratio,
embedder,
hybrid,
} = self; } = self;
if total_received == 0 { if total_received == 0 {
@ -917,6 +938,11 @@ impl SearchAggregator {
"vector": { "vector": {
"max_vector_size": max_vector_size, "max_vector_size": max_vector_size,
}, },
"hybrid": {
"enabled": hybrid,
"semantic_ratio": semantic_ratio,
"embedder": embedder,
},
"pagination": { "pagination": {
"max_limit": max_limit, "max_limit": max_limit,
"max_offset": max_offset, "max_offset": max_offset,
@ -1012,6 +1038,7 @@ impl MultiSearchAggregator {
crop_marker: _, crop_marker: _,
matching_strategy: _, matching_strategy: _,
attributes_to_search_on: _, attributes_to_search_on: _,
hybrid: _,
} = query; } = query;
index_uid.as_str() index_uid.as_str()
@ -1158,6 +1185,7 @@ impl FacetSearchAggregator {
filter, filter,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
} = query; } = query;
let mut ret = Self::default(); let mut ret = Self::default();
@ -1171,7 +1199,8 @@ impl FacetSearchAggregator {
|| vector.is_some() || vector.is_some()
|| filter.is_some() || filter.is_some()
|| *matching_strategy != MatchingStrategy::default() || *matching_strategy != MatchingStrategy::default()
|| attributes_to_search_on.is_some(); || attributes_to_search_on.is_some()
|| hybrid.is_some();
ret ret
} }

View File

@ -14,9 +14,9 @@ use crate::analytics::{Analytics, FacetSearchAggregator};
use crate::extractors::authentication::policies::*; use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData; use crate::extractors::authentication::GuardedData;
use crate::search::{ use crate::search::{
add_search_rules, perform_facet_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery,
DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
}; };
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
@ -37,6 +37,8 @@ pub struct FacetSearchQuery {
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<Vec<f32>>, pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)] #[deserr(default, error = DeserrJsonError<InvalidSearchFilter>)]
pub filter: Option<Value>, pub filter: Option<Value>,
#[deserr(default, error = DeserrJsonError<InvalidSearchMatchingStrategy>, default)] #[deserr(default, error = DeserrJsonError<InvalidSearchMatchingStrategy>, default)]
@ -96,6 +98,7 @@ impl From<FacetSearchQuery> for SearchQuery {
filter, filter,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
} = value; } = value;
SearchQuery { SearchQuery {
@ -120,6 +123,7 @@ impl From<FacetSearchQuery> for SearchQuery {
matching_strategy, matching_strategy,
vector: vector.map(VectorQuery::Vector), vector: vector.map(VectorQuery::Vector),
attributes_to_search_on, attributes_to_search_on,
hybrid,
} }
} }
} }

View File

@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::VectorQuery; use meilisearch_types::milli::{self, VectorQuery};
use meilisearch_types::serde_cs::vec::CS; use meilisearch_types::serde_cs::vec::CS;
use serde_json::Value; use serde_json::Value;
@ -17,9 +17,9 @@ use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData; use crate::extractors::authentication::GuardedData;
use crate::extractors::sequential_extractor::SeqHandler; use crate::extractors::sequential_extractor::SeqHandler;
use crate::search::{ use crate::search::{
add_search_rules, perform_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery,
DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO,
}; };
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
@ -75,6 +75,10 @@ pub struct SearchQueryGet {
matching_strategy: MatchingStrategy, matching_strategy: MatchingStrategy,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchAttributesToSearchOn>)]
pub attributes_to_search_on: Option<CS<String>>, pub attributes_to_search_on: Option<CS<String>>,
#[deserr(default, error = DeserrQueryParamError<InvalidHybridQuery>)]
pub hybrid_embedder: Option<String>,
#[deserr(default, error = DeserrQueryParamError<InvalidHybridQuery>)]
pub hybrid_semantic_ratio: Option<f32>,
} }
impl From<SearchQueryGet> for SearchQuery { impl From<SearchQueryGet> for SearchQuery {
@ -87,6 +91,18 @@ impl From<SearchQueryGet> for SearchQuery {
None => None, None => None,
}; };
let hybrid = match (other.hybrid_embedder, other.hybrid_semantic_ratio) {
(None, None) => None,
(None, Some(semantic_ratio)) => Some(HybridQuery { semantic_ratio, embedder: None }),
(Some(embedder), None) => Some(HybridQuery {
semantic_ratio: DEFAULT_SEMANTIC_RATIO(),
embedder: Some(embedder),
}),
(Some(embedder), Some(semantic_ratio)) => {
Some(HybridQuery { semantic_ratio, embedder: Some(embedder) })
}
};
Self { Self {
q: other.q, q: other.q,
vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector), vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector),
@ -109,6 +125,7 @@ impl From<SearchQueryGet> for SearchQuery {
crop_marker: other.crop_marker, crop_marker: other.crop_marker,
matching_strategy: other.matching_strategy, matching_strategy: other.matching_strategy,
attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()),
hybrid,
} }
} }
} }
@ -159,6 +176,9 @@ pub async fn search_with_url_query(
let index = index_scheduler.index(&index_uid)?; let index = index_scheduler.index(&index_uid)?;
let features = index_scheduler.features(); let features = index_scheduler.features();
embed(&mut query, index_scheduler.get_ref(), &index).await?;
let search_result = let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?;
if let Ok(ref search_result) = search_result { if let Ok(ref search_result) = search_result {
@ -213,22 +233,31 @@ pub async fn search_with_post(
pub async fn embed( pub async fn embed(
query: &mut SearchQuery, query: &mut SearchQuery,
index_scheduler: &IndexScheduler, index_scheduler: &IndexScheduler,
index: &meilisearch_types::milli::Index, index: &milli::Index,
) -> Result<(), ResponseError> { ) -> Result<(), ResponseError> {
if let Some(VectorQuery::String(prompt)) = query.vector.take() { if let Some(VectorQuery::String(prompt)) = query.vector.take() {
let embedder_configs = index.embedding_configs(&index.read_txn()?)?; let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedder = index_scheduler.embedders(embedder_configs)?; let embedder = index_scheduler.embedders(embedder_configs)?;
/// FIXME: add error if no embedder, remove unwrap, support multiple embedders let embedder_name = if let Some(HybridQuery {
semantic_ratio: _,
embedder: Some(embedder),
}) = &query.hybrid
{
embedder
} else {
"default"
};
let embeddings = embedder let embeddings = embedder
.get("default") .get(embedder_name)
.unwrap() .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned()))
.map_err(milli::Error::from)?
.0 .0
.embed(vec![prompt]) .embed(vec![prompt])
.await .await
.map_err(meilisearch_types::milli::vector::Error::from) .map_err(milli::vector::Error::from)
.map_err(meilisearch_types::milli::UserError::from) .map_err(milli::Error::from)?
.map_err(meilisearch_types::milli::Error::from)?
.pop() .pop()
.expect("No vector returned from embedding"); .expect("No vector returned from embedding");

View File

@ -36,6 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10;
pub const DEFAULT_CROP_MARKER: fn() -> String = || "".to_string(); pub const DEFAULT_CROP_MARKER: fn() -> String = || "".to_string();
pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string(); pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string(); pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
pub const DEFAULT_SEMANTIC_RATIO: fn() -> f32 = || 0.5;
#[derive(Debug, Clone, Default, PartialEq, Deserr)] #[derive(Debug, Clone, Default, PartialEq, Deserr)]
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
@ -44,6 +45,8 @@ pub struct SearchQuery {
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)] #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<milli::VectorQuery>, pub vector: Option<milli::VectorQuery>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
pub offset: usize, pub offset: usize,
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
@ -84,6 +87,15 @@ pub struct SearchQuery {
pub attributes_to_search_on: Option<Vec<String>>, pub attributes_to_search_on: Option<Vec<String>>,
} }
#[derive(Debug, Clone, Default, PartialEq, Deserr)]
#[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
pub struct HybridQuery {
#[deserr(default, error = DeserrJsonError<InvalidSemanticRatio>, default = DEFAULT_SEMANTIC_RATIO())]
pub semantic_ratio: f32,
#[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)]
pub embedder: Option<String>,
}
impl SearchQuery { impl SearchQuery {
pub fn is_finite_pagination(&self) -> bool { pub fn is_finite_pagination(&self) -> bool {
self.page.or(self.hits_per_page).is_some() self.page.or(self.hits_per_page).is_some()
@ -103,6 +115,8 @@ pub struct SearchQueryWithIndex {
pub q: Option<String>, pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)] #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub vector: Option<VectorQuery>, pub vector: Option<VectorQuery>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)] #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
pub offset: usize, pub offset: usize,
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)] #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
@ -168,6 +182,7 @@ impl SearchQueryWithIndex {
crop_marker, crop_marker,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
} = self; } = self;
( (
index_uid, index_uid,
@ -193,6 +208,7 @@ impl SearchQueryWithIndex {
crop_marker, crop_marker,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid,
// do not use ..Default::default() here, // do not use ..Default::default() here,
// rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex`
}, },

View File

@ -63,6 +63,8 @@ pub enum InternalError {
InvalidMatchingWords, InvalidMatchingWords,
#[error(transparent)] #[error(transparent)]
ArroyError(#[from] arroy::Error), ArroyError(#[from] arroy::Error),
#[error(transparent)]
VectorEmbeddingError(#[from] crate::vector::Error),
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]
@ -188,8 +190,23 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), MissingDocumentField(#[from] crate::prompt::error::RenderPromptError),
#[error(transparent)] #[error(transparent)]
InvalidPrompt(#[from] crate::prompt::error::NewPromptError), InvalidPrompt(#[from] crate::prompt::error::NewPromptError),
#[error("Invalid prompt in for embeddings with name '{0}': {1}")] #[error("Invalid prompt in for embeddings with name '{0}': {1}.")]
InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError),
#[error("Too many embedders in the configuration. Found {0}, but limited to 256.")]
TooManyEmbedders(usize),
#[error("Cannot find embedder with name {0}.")]
InvalidEmbedder(String),
}
impl From<crate::vector::Error> for Error {
fn from(value: crate::vector::Error) -> Self {
match value.fault() {
FaultSource::User => Error::UserError(value.into()),
FaultSource::Runtime => Error::InternalError(value.into()),
FaultSource::Bug => Error::InternalError(value.into()),
FaultSource::Undecided => Error::InternalError(value.into()),
}
}
} }
impl From<arroy::Error> for Error { impl From<arroy::Error> for Error {

View File

@ -110,7 +110,6 @@ impl Prompt {
}; };
// render template with special object that's OK with `doc.*` and `fields.*` // render template with special object that's OK with `doc.*` and `fields.*`
/// FIXME: doesn't work for nested objects e.g. `doc.a.b`
this.template this.template
.render(&template_checker::TemplateChecker) .render(&template_checker::TemplateChecker)
.map_err(NewPromptError::invalid_fields_in_template)?; .map_err(NewPromptError::invalid_fields_in_template)?;
@ -142,3 +141,80 @@ pub enum PromptFallbackStrategy {
#[default] #[default]
Error, Error,
} }
#[cfg(test)]
mod test {
use super::Prompt;
use crate::error::FaultSource;
use crate::prompt::error::{NewPromptError, NewPromptErrorKind};
#[test]
fn default_template() {
// does not panic
Prompt::default();
}
#[test]
fn empty_template() {
Prompt::new("".into(), None, None).unwrap();
}
#[test]
fn template_ok() {
Prompt::new("{{doc.title}}: {{doc.overview}}".into(), None, None).unwrap();
}
#[test]
fn template_syntax() {
assert!(matches!(
Prompt::new("{{doc.title: {{doc.overview}}".into(), None, None),
Err(NewPromptError {
kind: NewPromptErrorKind::CannotParseTemplate(_),
fault: FaultSource::User
})
));
}
#[test]
fn template_missing_doc() {
assert!(matches!(
Prompt::new("{{title}}: {{overview}}".into(), None, None),
Err(NewPromptError {
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
fault: FaultSource::User
})
));
}
#[test]
fn template_nested_doc() {
Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(), None, None).unwrap();
}
#[test]
fn template_fields() {
Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into(), None, None).unwrap();
}
#[test]
fn template_fields_ok() {
Prompt::new(
"{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(),
None,
None,
)
.unwrap();
}
#[test]
fn template_fields_invalid() {
assert!(matches!(
// intentionally garbled field
Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None, None),
Err(NewPromptError {
kind: NewPromptErrorKind::InvalidFieldsInTemplate(_),
fault: FaultSource::User
})
));
}
}

View File

@ -1,7 +1,7 @@
use liquid::model::{ use liquid::model::{
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
}; };
use liquid::{ObjectView, ValueView}; use liquid::{Object, ObjectView, ValueView};
#[derive(Debug)] #[derive(Debug)]
pub struct TemplateChecker; pub struct TemplateChecker;
@ -31,11 +31,11 @@ impl ObjectView for DummyField {
} }
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(std::iter::empty()) Box::new(vec![DUMMY_VALUE.as_view(), DUMMY_VALUE.as_view()].into_iter())
} }
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> { fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
Box::new(std::iter::empty()) Box::new(self.keys().zip(self.values()))
} }
fn contains_key(&self, index: &str) -> bool { fn contains_key(&self, index: &str) -> bool {
@ -69,7 +69,12 @@ impl ValueView for DummyField {
} }
fn query_state(&self, state: State) -> bool { fn query_state(&self, state: State) -> bool {
DUMMY_VALUE.query_state(state) match state {
State::Truthy => true,
State::DefaultValue => false,
State::Empty => false,
State::Blank => false,
}
} }
fn to_kstr(&self) -> KStringCow<'_> { fn to_kstr(&self) -> KStringCow<'_> {
@ -77,7 +82,10 @@ impl ValueView for DummyField {
} }
fn to_value(&self) -> LiquidValue { fn to_value(&self) -> LiquidValue {
LiquidValue::Nil let mut this = Object::new();
this.insert("name".into(), LiquidValue::Nil);
this.insert("value".into(), LiquidValue::Nil);
LiquidValue::Object(this)
} }
fn as_object(&self) -> Option<&dyn ObjectView> { fn as_object(&self) -> Option<&dyn ObjectView> {
@ -103,7 +111,12 @@ impl ValueView for DummyFields {
} }
fn query_state(&self, state: State) -> bool { fn query_state(&self, state: State) -> bool {
DUMMY_VALUE.query_state(state) match state {
State::Truthy => true,
State::DefaultValue => false,
State::Empty => false,
State::Blank => false,
}
} }
fn to_kstr(&self) -> KStringCow<'_> { fn to_kstr(&self) -> KStringCow<'_> {
@ -111,7 +124,7 @@ impl ValueView for DummyFields {
} }
fn to_value(&self) -> LiquidValue { fn to_value(&self) -> LiquidValue {
LiquidValue::Nil LiquidValue::Array(vec![DummyField.to_value()])
} }
fn as_array(&self) -> Option<&dyn ArrayView> { fn as_array(&self) -> Option<&dyn ArrayView> {
@ -125,15 +138,15 @@ impl ArrayView for DummyFields {
} }
fn size(&self) -> i64 { fn size(&self) -> i64 {
i64::MAX u16::MAX as i64
} }
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> { fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
Box::new(std::iter::empty()) Box::new(std::iter::once(DummyField.as_value()))
} }
fn contains_key(&self, _index: i64) -> bool { fn contains_key(&self, index: i64) -> bool {
true index < self.size()
} }
fn get(&self, _index: i64) -> Option<&dyn ValueView> { fn get(&self, _index: i64) -> Option<&dyn ValueView> {
@ -167,7 +180,8 @@ impl ObjectView for DummyDoc {
} }
fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> { fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> {
Some(DUMMY_VALUE.as_view()) // Recursively sends itself
Some(self)
} }
} }
@ -189,7 +203,12 @@ impl ValueView for DummyDoc {
} }
fn query_state(&self, state: State) -> bool { fn query_state(&self, state: State) -> bool {
DUMMY_VALUE.query_state(state) match state {
State::Truthy => true,
State::DefaultValue => false,
State::Empty => false,
State::Blank => false,
}
} }
fn to_kstr(&self) -> KStringCow<'_> { fn to_kstr(&self) -> KStringCow<'_> {

View File

@ -516,7 +516,7 @@ pub fn execute_vector_search(
) -> Result<PartialSearchResult> { ) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?; check_sort_criteria(ctx, sort_criteria.as_ref())?;
/// FIXME: input universe = universe & documents_with_vectors // FIXME: input universe = universe & documents_with_vectors
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe // for now if we're computing embeddings for ALL documents, we can assume that this is just universe
let ranking_rules = get_ranking_rules_for_vector( let ranking_rules = get_ranking_rules_for_vector(
ctx, ctx,

View File

@ -71,8 +71,8 @@ impl VectorStateDelta {
pub fn extract_vector_points<R: io::Read + io::Seek>( pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>, obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters, indexer: GrenadParameters,
field_id_map: FieldsIdsMap, field_id_map: &FieldsIdsMap,
prompt: Option<&Prompt>, prompt: &Prompt,
) -> Result<ExtractedVectorPoints> { ) -> Result<ExtractedVectorPoints> {
puffin::profile_function!(); puffin::profile_function!();
@ -142,14 +142,11 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.any(|deladd| deladd.get(DelAdd::Addition).is_some()); .any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept { if document_is_kept {
// becomes autogenerated // becomes autogenerated
match prompt { VectorStateDelta::NowGenerated(prompt.render(
Some(prompt) => VectorStateDelta::NowGenerated(prompt.render( obkv,
obkv, DelAdd::Addition,
DelAdd::Addition, field_id_map,
&field_id_map, )?)
)?),
None => VectorStateDelta::NowRemoved,
}
} else { } else {
VectorStateDelta::NowRemoved VectorStateDelta::NowRemoved
} }
@ -162,26 +159,18 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.any(|deladd| deladd.get(DelAdd::Addition).is_some()); .any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept { if document_is_kept {
match prompt { // Don't give up if the old prompt was failing
Some(prompt) => { let old_prompt =
// Don't give up if the old prompt was failing prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let old_prompt = prompt let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
.render(obkv, DelAdd::Deletion, &field_id_map) if old_prompt != new_prompt {
.unwrap_or_default(); log::trace!(
let new_prompt = "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
prompt.render(obkv, DelAdd::Addition, &field_id_map)?; );
if old_prompt != new_prompt { VectorStateDelta::NowGenerated(new_prompt)
log::trace!( } else {
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" log::trace!("⏭️ Prompt unmodified, skipping");
); VectorStateDelta::NoChange
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
}
// We no longer have a prompt, so we need to remove any existing vector
None => VectorStateDelta::NowRemoved,
} }
} else { } else {
VectorStateDelta::NowRemoved VectorStateDelta::NowRemoved
@ -196,24 +185,16 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.any(|deladd| deladd.get(DelAdd::Addition).is_some()); .any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept { if document_is_kept {
match prompt { // Don't give up if the old prompt was failing
Some(prompt) => { let old_prompt =
// Don't give up if the old prompt was failing prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let old_prompt = prompt let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
.render(obkv, DelAdd::Deletion, &field_id_map) if old_prompt != new_prompt {
.unwrap_or_default(); log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}");
let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?; VectorStateDelta::NowGenerated(new_prompt)
if old_prompt != new_prompt { } else {
log::trace!( log::trace!("⏭️ Prompt unmodified, skipping");
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" VectorStateDelta::NoChange
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
}
None => VectorStateDelta::NowRemoved,
} }
} else { } else {
VectorStateDelta::NowRemoved VectorStateDelta::NowRemoved
@ -322,7 +303,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
prompt_reader: grenad::Reader<R>, prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters, indexer: GrenadParameters,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
) -> Result<(grenad::Reader<BufReader<File>>, Option<usize>)> { ) -> Result<grenad::Reader<BufReader<File>>> {
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
@ -341,8 +322,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
let mut chunks_ids = Vec::with_capacity(n_chunks); let mut chunks_ids = Vec::with_capacity(n_chunks);
let mut cursor = prompt_reader.into_cursor()?; let mut cursor = prompt_reader.into_cursor()?;
let mut expected_dimension = None;
while let Some((key, value)) = cursor.move_on_next()? { while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
// SAFETY: precondition, the grenad value was saved from a string // SAFETY: precondition, the grenad value was saved from a string
@ -367,7 +346,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
) )
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?; .map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids for (docid, embeddings) in chunks_ids
@ -376,7 +354,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{ {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
expected_dimension = Some(embeddings.dimension());
} }
chunks_ids.clear(); chunks_ids.clear();
} }
@ -387,7 +364,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
let chunked_embeds = rt let chunked_embeds = rt
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) .block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?; .map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids for (docid, embeddings) in chunks_ids
.iter() .iter()
@ -395,7 +371,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{ {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
expected_dimension = Some(embeddings.dimension());
} }
} }
@ -403,14 +378,12 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
let embeds = rt let embeds = rt
.block_on(embedder.embed(std::mem::take(&mut current_chunk))) .block_on(embedder.embed(std::mem::take(&mut current_chunk)))
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?; .map_err(crate::Error::from)?;
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
expected_dimension = Some(embeddings.dimension());
} }
} }
Ok((writer_into_reader(state_writer)?, expected_dimension)) writer_into_reader(state_writer)
} }

View File

@ -292,43 +292,42 @@ fn send_original_documents_data(
let documents_chunk_cloned = original_documents_chunk.clone(); let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || { rayon::spawn(move || {
let (embedder, prompt) = embedders.get("default").cloned().unzip(); for (name, (embedder, prompt)) in embedders {
let result = let result = extract_vector_points(
extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref()); documents_chunk_cloned.clone(),
match result { indexer,
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { &field_id_map,
/// FIXME: support multiple embedders &prompt,
let results = embedder.and_then(|embedder| { );
match extract_embeddings(prompts, indexer, embedder.clone()) { match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
Ok(results) => Some(results), Ok(results) => Some(results),
Err(error) => { Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error)); let _ = lmdb_writer_sx_cloned.send(Err(error));
None None
} }
} };
});
let (embeddings, expected_dimension) = results.unzip(); if !(remove_vectors.is_empty()
let expected_dimension = expected_dimension.flatten(); && manual_vectors.is_empty()
if !(remove_vectors.is_empty() && embeddings.as_ref().map_or(true, |e| e.is_empty()))
&& manual_vectors.is_empty() {
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
{
/// FIXME FIXME FIXME
if expected_dimension.is_some() {
let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
remove_vectors, remove_vectors,
embeddings, embeddings,
/// FIXME: compute an expected dimension from the manual vectors if any expected_dimension: embedder.dimensions(),
expected_dimension: expected_dimension.unwrap(),
manual_vectors, manual_vectors,
embedder_name: name,
})); }));
} }
} }
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
}
} }
Err(error) => { }
let _ = lmdb_writer_sx_cloned.send(Err(error));
}
};
}); });
// TODO: create a custom internal error // TODO: create a custom internal error

View File

@ -435,7 +435,7 @@ where
let mut word_docids = None; let mut word_docids = None;
let mut exact_word_docids = None; let mut exact_word_docids = None;
let mut dimension = None; let mut dimension = HashMap::new();
for result in lmdb_writer_rx { for result in lmdb_writer_rx {
if (self.should_abort)() { if (self.should_abort)() {
@ -471,13 +471,15 @@ where
remove_vectors, remove_vectors,
embeddings, embeddings,
manual_vectors, manual_vectors,
embedder_name,
} => { } => {
dimension = Some(expected_dimension); dimension.insert(embedder_name.clone(), expected_dimension);
TypedChunk::VectorPoints { TypedChunk::VectorPoints {
remove_vectors, remove_vectors,
embeddings, embeddings,
expected_dimension, expected_dimension,
manual_vectors, manual_vectors,
embedder_name,
} }
} }
otherwise => otherwise, otherwise => otherwise,
@ -513,14 +515,22 @@ where
self.index.put_primary_key(self.wtxn, &primary_key)?; self.index.put_primary_key(self.wtxn, &primary_key)?;
let number_of_documents = self.index.number_of_documents(self.wtxn)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?;
if let Some(dimension) = dimension { for (embedder_name, dimension) in dimension {
let wtxn = &mut *self.wtxn; let wtxn = &mut *self.wtxn;
let vector_arroy = self.index.vector_arroy; let vector_arroy = self.index.vector_arroy;
/// FIXME: unwrap
let embedder_index =
self.index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
pool.install(|| { pool.install(|| {
/// FIXME: do for each embedder let writer_index = (embedder_index as u16) << 8;
let mut rng = rand::rngs::StdRng::from_entropy(); let mut rng = rand::rngs::StdRng::from_entropy();
for k in 0..=u8::MAX { for k in 0..=u8::MAX {
let writer = arroy::Writer::prepare(wtxn, vector_arroy, k.into(), dimension)?; let writer = arroy::Writer::prepare(
wtxn,
vector_arroy,
writer_index | (k as u16),
dimension,
)?;
if writer.is_empty(wtxn)? { if writer.is_empty(wtxn)? {
break; break;
} }

View File

@ -47,6 +47,7 @@ pub(crate) enum TypedChunk {
embeddings: Option<grenad::Reader<BufReader<File>>>, embeddings: Option<grenad::Reader<BufReader<File>>>,
expected_dimension: usize, expected_dimension: usize,
manual_vectors: grenad::Reader<BufReader<File>>, manual_vectors: grenad::Reader<BufReader<File>>,
embedder_name: String,
}, },
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
} }
@ -100,8 +101,8 @@ impl TypedChunk {
TypedChunk::GeoPoints(grenad) => { TypedChunk::GeoPoints(grenad) => {
format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
} }
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => { TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => {
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension) format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name)
} }
TypedChunk::ScriptLanguageDocids(sl_map) => { TypedChunk::ScriptLanguageDocids(sl_map) => {
format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
@ -360,12 +361,20 @@ pub(crate) fn write_typed_chunk_into_index(
manual_vectors, manual_vectors,
embeddings, embeddings,
expected_dimension, expected_dimension,
embedder_name,
} => { } => {
/// FIXME: allow customizing distance /// FIXME: unwrap
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
let writer_index = (embedder_index as u16) << 8;
// FIXME: allow customizing distance
let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX) let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map(|k| { .map(|k| {
/// FIXME: allow customizing index and then do index << 8 + k arroy::Writer::prepare(
arroy::Writer::prepare(wtxn, index.vector_arroy, k.into(), expected_dimension) wtxn,
index.vector_arroy,
writer_index | (k as u16),
expected_dimension,
)
}) })
.collect(); .collect();
let writers = writers?; let writers = writers?;
@ -456,7 +465,7 @@ pub(crate) fn write_typed_chunk_into_index(
} }
} }
log::debug!("There are 🤷‍♀️ entries in the arroy so far"); log::debug!("Finished vector chunk for {}", embedder_name);
} }
TypedChunk::ScriptLanguageDocids(sl_map) => { TypedChunk::ScriptLanguageDocids(sl_map) => {
for (key, (deletion, addition)) in sl_map { for (key, (deletion, addition)) in sl_map {

View File

@ -431,7 +431,6 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
let embedder = Arc::new( let embedder = Arc::new(
Embedder::new(embedder_options.clone()) Embedder::new(embedder_options.clone())
.map_err(crate::vector::Error::from) .map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?, .map_err(crate::Error::from)?,
); );
Ok((name, (embedder, prompt))) Ok((name, (embedder, prompt)))
@ -976,6 +975,19 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Setting::NotSet => Some((name, EmbeddingSettings::default().into())), Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
}) })
.collect(); .collect();
self.index.embedder_category_id.clear(self.wtxn)?;
for (index, (embedder_name, _)) in new_configs.iter().enumerate() {
self.index.embedder_category_id.put_with_flags(
self.wtxn,
heed::PutFlags::APPEND,
embedder_name,
&index
.try_into()
.map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?,
)?;
}
if new_configs.is_empty() { if new_configs.is_empty() {
self.index.delete_embedding_configs(self.wtxn)?; self.index.delete_embedding_configs(self.wtxn)?;
} else { } else {
@ -1062,7 +1074,7 @@ fn validate_prompt(
match new { match new {
Setting::Set(EmbeddingSettings { Setting::Set(EmbeddingSettings {
embedder_options, embedder_options,
prompt: document_template:
Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }), Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }),
}) => { }) => {
// validate // validate
@ -1072,7 +1084,7 @@ fn validate_prompt(
Ok(Setting::Set(EmbeddingSettings { Ok(Setting::Set(EmbeddingSettings {
embedder_options, embedder_options,
prompt: Setting::Set(PromptSettings { document_template: Setting::Set(PromptSettings {
template: Setting::Set(template), template: Setting::Set(template),
strategy, strategy,
fallback, fallback,

View File

@ -65,6 +65,8 @@ pub enum EmbedErrorKind {
OpenAiTooManyTokens(OpenAiError), OpenAiTooManyTokens(OpenAiError),
#[error("received unhandled HTTP status code {0} from OpenAI")] #[error("received unhandled HTTP status code {0} from OpenAI")]
OpenAiUnhandledStatusCode(u16), OpenAiUnhandledStatusCode(u16),
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
ManualEmbed(String),
} }
impl EmbedError { impl EmbedError {
@ -111,6 +113,10 @@ impl EmbedError {
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
} }
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
}
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@ -170,6 +176,13 @@ impl NewEmbedderError {
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime } Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
} }
pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
fault: FaultSource::Runtime,
}
}
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
} }
@ -219,6 +232,8 @@ pub enum NewEmbedderErrorKind {
NewApiFail(ApiError), NewApiFail(ApiError),
#[error("fetching file from HG_HUB failed: {0}")] #[error("fetching file from HG_HUB failed: {0}")]
ApiGet(ApiError), ApiGet(ApiError),
#[error("could not determine model dimensions: test embedding failed with {0}")]
CouldNotDetermineDimension(EmbedError),
#[error("loading model failed: {0}")] #[error("loading model failed: {0}")]
LoadModel(candle_core::Error), LoadModel(candle_core::Error),
// openai // openai

View File

@ -62,6 +62,7 @@ pub struct Embedder {
model: BertModel, model: BertModel,
tokenizer: Tokenizer, tokenizer: Tokenizer,
options: EmbedderOptions, options: EmbedderOptions,
dimensions: usize,
} }
impl std::fmt::Debug for Embedder { impl std::fmt::Debug for Embedder {
@ -126,10 +127,17 @@ impl Embedder {
tokenizer.with_padding(Some(pp)); tokenizer.with_padding(Some(pp));
} }
Ok(Self { model, tokenizer, options }) let mut this = Self { model, tokenizer, options, dimensions: 0 };
let embeddings = this
.embed(vec!["test".into()])
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
this.dimensions = embeddings.first().unwrap().dimension();
Ok(this)
} }
pub async fn embed( pub fn embed(
&self, &self,
mut texts: Vec<String>, mut texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
@ -170,12 +178,11 @@ impl Embedder {
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
} }
pub async fn embed_chunks( pub fn embed_chunks(
&self, &self,
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
.await
} }
pub fn chunk_count_hint(&self) -> usize { pub fn chunk_count_hint(&self) -> usize {
@ -185,6 +192,10 @@ impl Embedder {
pub fn prompt_count_in_chunk_hint(&self) -> usize { pub fn prompt_count_in_chunk_hint(&self) -> usize {
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8) std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
} }
pub fn dimensions(&self) -> usize {
self.dimensions
}
} }
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> { fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {

View File

@ -3,6 +3,7 @@ use crate::prompt::PromptData;
pub mod error; pub mod error;
pub mod hf; pub mod hf;
pub mod manual;
pub mod openai; pub mod openai;
pub mod settings; pub mod settings;
@ -67,6 +68,7 @@ impl<F> Embeddings<F> {
pub enum Embedder { pub enum Embedder {
HuggingFace(hf::Embedder), HuggingFace(hf::Embedder),
OpenAi(openai::Embedder), OpenAi(openai::Embedder),
UserProvided(manual::Embedder),
} }
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
@ -80,6 +82,7 @@ pub struct EmbeddingConfig {
pub enum EmbedderOptions { pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions), HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions), OpenAi(openai::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
} }
impl Default for EmbedderOptions { impl Default for EmbedderOptions {
@ -93,7 +96,7 @@ impl EmbedderOptions {
Self::HuggingFace(hf::EmbedderOptions::new()) Self::HuggingFace(hf::EmbedderOptions::new())
} }
pub fn openai(api_key: String) -> Self { pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
} }
} }
@ -103,6 +106,9 @@ impl Embedder {
Ok(match options { Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
}) })
} }
@ -111,8 +117,9 @@ impl Embedder {
texts: Vec<String>, texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> { ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts).await, Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(texts).await, Embedder::OpenAi(embedder) => embedder.embed(texts).await,
Embedder::UserProvided(embedder) => embedder.embed(texts),
} }
} }
@ -121,8 +128,9 @@ impl Embedder {
text_chunks: Vec<Vec<String>>, text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await, Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
} }
} }
@ -130,6 +138,7 @@ impl Embedder {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
} }
} }
@ -137,6 +146,15 @@ impl Embedder {
match self { match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
}
}
pub fn dimensions(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
} }
} }
} }

View File

@ -15,7 +15,7 @@ pub struct Embedder {
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub api_key: String, pub api_key: Option<String>,
pub embedding_model: EmbeddingModel, pub embedding_model: EmbeddingModel,
} }
@ -68,11 +68,11 @@ impl EmbeddingModel {
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model(api_key: String) -> Self { pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default() } Self { api_key, embedding_model: Default::default() }
} }
pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self { pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model } Self { api_key, embedding_model }
} }
} }
@ -80,9 +80,14 @@ impl EmbedderOptions {
impl Embedder { impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new(); let mut headers = reqwest::header::HeaderMap::new();
let mut inferred_api_key = Default::default();
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
inferred_api_key = infer_api_key();
&inferred_api_key
});
headers.insert( headers.insert(
reqwest::header::AUTHORIZATION, reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key)) reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
.map_err(NewEmbedderError::openai_invalid_api_key_format)?, .map_err(NewEmbedderError::openai_invalid_api_key_format)?,
); );
headers.insert( headers.insert(
@ -315,6 +320,10 @@ impl Embedder {
pub fn prompt_count_in_chunk_hint(&self) -> usize { pub fn prompt_count_in_chunk_hint(&self) -> usize {
10 10
} }
pub fn dimensions(&self) -> usize {
self.options.embedding_model.dimensions()
}
} }
// retrying in case of failure // retrying in case of failure
@ -414,3 +423,9 @@ struct OpenAiEmbedding {
// object: String, // object: String,
// index: usize, // index: usize,
} }
fn infer_api_key() -> String {
std::env::var("MEILI_OPENAI_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.unwrap_or_default()
}

View File

@ -15,14 +15,14 @@ pub struct EmbeddingSettings {
pub embedder_options: Setting<EmbedderSettings>, pub embedder_options: Setting<EmbedderSettings>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)] #[deserr(default)]
pub prompt: Setting<PromptSettings>, pub document_template: Setting<PromptSettings>,
} }
impl EmbeddingSettings { impl EmbeddingSettings {
pub fn apply(&mut self, new: Self) { pub fn apply(&mut self, new: Self) {
let EmbeddingSettings { embedder_options, prompt } = new; let EmbeddingSettings { embedder_options, document_template: prompt } = new;
self.embedder_options.apply(embedder_options); self.embedder_options.apply(embedder_options);
self.prompt.apply(prompt); self.document_template.apply(prompt);
} }
} }
@ -30,7 +30,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self { fn from(value: EmbeddingConfig) -> Self {
Self { Self {
embedder_options: Setting::Set(value.embedder_options.into()), embedder_options: Setting::Set(value.embedder_options.into()),
prompt: Setting::Set(value.prompt.into()), document_template: Setting::Set(value.prompt.into()),
} }
} }
} }
@ -38,7 +38,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
impl From<EmbeddingSettings> for EmbeddingConfig { impl From<EmbeddingSettings> for EmbeddingConfig {
fn from(value: EmbeddingSettings) -> Self { fn from(value: EmbeddingSettings) -> Self {
let mut this = Self::default(); let mut this = Self::default();
let EmbeddingSettings { embedder_options, prompt } = value; let EmbeddingSettings { embedder_options, document_template: prompt } = value;
if let Some(embedder_options) = embedder_options.set() { if let Some(embedder_options) = embedder_options.set() {
this.embedder_options = embedder_options.into(); this.embedder_options = embedder_options.into();
} }
@ -105,6 +105,7 @@ impl From<PromptSettings> for PromptData {
pub enum EmbedderSettings { pub enum EmbedderSettings {
HuggingFace(Setting<HfEmbedderSettings>), HuggingFace(Setting<HfEmbedderSettings>),
OpenAi(Setting<OpenAiEmbedderSettings>), OpenAi(Setting<OpenAiEmbedderSettings>),
UserProvided(UserProvidedSettings),
} }
impl<E> Deserr<E> for EmbedderSettings impl<E> Deserr<E> for EmbedderSettings
@ -145,11 +146,17 @@ where
location.push_key(&k), location.push_key(&k),
)?, )?,
))), ))),
"userProvided" => Ok(EmbedderSettings::UserProvided(
UserProvidedSettings::deserialize_from_value(
v.into_value(),
location.push_key(&k),
)?,
)),
other => Err(deserr::take_cf_content(E::error::<V>( other => Err(deserr::take_cf_content(E::error::<V>(
None, None,
deserr::ErrorKind::UnknownKey { deserr::ErrorKind::UnknownKey {
key: other, key: other,
accepted: &["huggingFace", "openAi"], accepted: &["huggingFace", "openAi", "userProvided"],
}, },
location, location,
))), ))),
@ -182,6 +189,9 @@ impl From<crate::vector::EmbedderOptions> for EmbedderSettings {
crate::vector::EmbedderOptions::OpenAi(openai) => { crate::vector::EmbedderOptions::OpenAi(openai) => {
Self::OpenAi(Setting::Set(openai.into())) Self::OpenAi(Setting::Set(openai.into()))
} }
crate::vector::EmbedderOptions::UserProvided(user_provided) => {
Self::UserProvided(user_provided.into())
}
} }
} }
} }
@ -192,9 +202,12 @@ impl From<EmbedderSettings> for crate::vector::EmbedderOptions {
EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()), EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()),
EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()), EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()),
EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()), EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()),
EmbedderSettings::OpenAi(_setting) => Self::OpenAi( EmbedderSettings::OpenAi(_setting) => {
crate::vector::openai::EmbedderOptions::with_default_model(infer_api_key()), Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None))
), }
EmbedderSettings::UserProvided(user_provided) => {
Self::UserProvided(user_provided.into())
}
} }
} }
} }
@ -286,7 +299,7 @@ impl OpenAiEmbedderSettings {
impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings { impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
fn from(value: crate::vector::openai::EmbedderOptions) -> Self { fn from(value: crate::vector::openai::EmbedderOptions) -> Self {
Self { Self {
api_key: Setting::Set(value.api_key), api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset),
embedding_model: Setting::Set(value.embedding_model), embedding_model: Setting::Set(value.embedding_model),
} }
} }
@ -295,14 +308,25 @@ impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions { impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions {
fn from(value: OpenAiEmbedderSettings) -> Self { fn from(value: OpenAiEmbedderSettings) -> Self {
let OpenAiEmbedderSettings { api_key, embedding_model } = value; let OpenAiEmbedderSettings { api_key, embedding_model } = value;
Self { Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() }
api_key: api_key.set().unwrap_or_else(infer_api_key),
embedding_model: embedding_model.set().unwrap_or_default(),
}
} }
} }
fn infer_api_key() -> String { #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
/// FIXME: get key from instance options? #[serde(deny_unknown_fields, rename_all = "camelCase")]
std::env::var("MEILI_OPENAI_API_KEY").unwrap_or_default() #[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct UserProvidedSettings {
pub dimensions: usize,
}
impl From<UserProvidedSettings> for crate::vector::manual::EmbedderOptions {
fn from(value: UserProvidedSettings) -> Self {
Self { dimensions: value.dimensions }
}
}
impl From<crate::vector::manual::EmbedderOptions> for UserProvidedSettings {
fn from(value: crate::vector::manual::EmbedderOptions) -> Self {
Self { dimensions: value.dimensions }
}
} }