diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs
index 65d257ea0..b9b360fa4 100644
--- a/index-scheduler/src/lib.rs
+++ b/index-scheduler/src/lib.rs
@@ -52,7 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128};
 use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn};
 use meilisearch_types::milli::documents::DocumentsBatchBuilder;
 use meilisearch_types::milli::update::IndexerConfig;
-use meilisearch_types::milli::vector::{Embedder, EmbedderOptions};
+use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs};
 use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32};
 use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task};
 use puffin::FrameView;
@@ -1339,11 +1339,10 @@ impl IndexScheduler {
     }
 
     // TODO: consider using a type alias or a struct embedder/template
-    #[allow(clippy::type_complexity)]
     pub fn embedders(
         &self,
         embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>,
-    ) -> Result<HashMap<String, (Arc<Embedder>, Arc<Prompt>)>> {
+    ) -> Result<EmbeddingConfigs> {
         let res: Result<_> = embedding_configs
             .into_iter()
             .map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| {
@@ -1370,7 +1369,7 @@ impl IndexScheduler {
                 Ok((name, (embedder, prompt)))
             })
             .collect();
-        res
+        res.map(EmbeddingConfigs::new)
     }
 
     /// Blocks the thread until the test handle asks to progress to/through this breakpoint.
diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs
index c057d4809..7a9a14687 100644
--- a/meilisearch/src/routes/indexes/search.rs
+++ b/meilisearch/src/routes/indexes/search.rs
@@ -238,22 +238,28 @@ pub async fn embed(
     match query.vector.take() {
         Some(VectorQuery::String(prompt)) => {
             let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
-            let embedder = index_scheduler.embedders(embedder_configs)?;
+            let embedders = index_scheduler.embedders(embedder_configs)?;
 
             let embedder_name =
                 if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) =
                     &query.hybrid
                 {
-                    embedder
+                    Some(embedder)
                 } else {
-                    "default"
+                    None
                 };
 
-            let embeddings = embedder
-                .get(embedder_name)
-                .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned()))
+            let embedder = if let Some(embedder_name) = embedder_name {
+                embedders.get(embedder_name)
+            } else {
+                embedders.get_default()
+            };
+
+            let embedder = embedder
+                .ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
                 .map_err(milli::Error::from)?
-                .0
+                .0;
+
+            let embeddings = embedder
                 .embed(vec![prompt])
                 .await
                 .map_err(milli::vector::Error::from)
diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs
index d496da1a3..53f6140fb 100644
--- a/meilisearch/src/search.rs
+++ b/meilisearch/src/search.rs
@@ -398,6 +398,10 @@ fn prepare_search<'t>(
         features.check_vector("Passing `vector` as a query parameter")?;
     }
 
+    if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
+        search.embedder_name(embedder);
+    }
+
     // compute the offset on the limit depending on the pagination mode.
     let (offset, limit) = if is_finite_pagination {
         let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
diff --git a/milli/src/index.rs b/milli/src/index.rs
index 05babf410..6ad39dcb1 100644
--- a/milli/src/index.rs
+++ b/milli/src/index.rs
@@ -1499,6 +1499,14 @@ impl Index {
             .get(rtxn, main_key::EMBEDDING_CONFIGS)?
             .unwrap_or_default())
     }
+
+    pub fn default_embedding_name(&self, rtxn: &RoTxn<'_>) -> Result<String> {
+        let configs = self.embedding_configs(rtxn)?;
+        Ok(match configs.as_slice() {
+            [(ref first_name, _)] => first_name.clone(),
+            _ => "default".to_owned(),
+        })
+    }
 }
 
 #[cfg(test)]
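Note: the new `Index::default_embedding_name` helper above decides which embedder a search falls back to when none is named. A minimal standalone sketch of that rule, with a `()` payload standing in for milli's `EmbeddingConfig` values:

```rust
// Sketch only: mirrors the rule in `default_embedding_name` using std types.
fn default_embedding_name(configs: &[(String, ())]) -> String {
    match configs {
        // exactly one embedder configured: its name is the default, whatever it is
        [(name, _)] => name.clone(),
        // zero or several embedders: fall back to the conventional "default" name
        _ => "default".to_owned(),
    }
}

fn main() {
    let one = vec![("custom".to_owned(), ())];
    let many = vec![("a".to_owned(), ()), ("b".to_owned(), ())];
    assert_eq!(default_embedding_name(&one), "custom");
    assert_eq!(default_embedding_name(&many), "default");
    assert_eq!(default_embedding_name(&[]), "default");
}
```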
diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs
index 02c518126..cbec20c65 100644
--- a/milli/src/search/hybrid.rs
+++ b/milli/src/search/hybrid.rs
@@ -218,6 +218,8 @@ impl<'a> Search<'a> {
             exhaustive_number_hits: self.exhaustive_number_hits,
             rtxn: self.rtxn,
             index: self.index,
+            distribution_shift: self.distribution_shift,
+            embedder_name: self.embedder_name.clone(),
         };
 
         let vector_query = search.vector.take();
@@ -265,6 +267,15 @@ impl<'a> Search<'a> {
         vector: &[f32],
         keyword_results: &SearchResult,
     ) -> Result<PartialSearchResult> {
+        let embedder_name;
+        let embedder_name = match &self.embedder_name {
+            Some(embedder_name) => embedder_name,
+            None => {
+                embedder_name = self.index.default_embedding_name(self.rtxn)?;
+                &embedder_name
+            }
+        };
+
         let mut ctx = SearchContext::new(self.index, self.rtxn);
 
         if let Some(searchable_attributes) = self.searchable_attributes {
@@ -282,6 +293,8 @@ impl<'a> Search<'a> {
             self.geo_strategy,
             0,
             self.limit + self.offset,
+            self.distribution_shift,
+            embedder_name,
         )
     }
 
diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs
index 8b541ffcd..04a6005e3 100644
--- a/milli/src/search/mod.rs
+++ b/milli/src/search/mod.rs
@@ -17,6 +17,7 @@ use self::new::{execute_vector_search, PartialSearchResult};
 use crate::error::UserError;
 use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
 use crate::score_details::{ScoreDetails, ScoringStrategy};
+use crate::vector::DistributionShift;
 use crate::{
     execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
     Result, SearchContext,
@@ -51,6 +52,8 @@ pub struct Search<'a> {
     exhaustive_number_hits: bool,
     rtxn: &'a heed::RoTxn<'a>,
     index: &'a Index,
+    distribution_shift: Option<DistributionShift>,
+    embedder_name: Option<String>,
 }
 
 #[derive(Debug, Clone, PartialEq)]
@@ -117,6 +120,8 @@ impl<'a> Search<'a> {
             words_limit: 10,
             rtxn,
             index,
+            distribution_shift: None,
+            embedder_name: None,
         }
     }
 
@@ -183,7 +188,29 @@ impl<'a> Search<'a> {
         self
     }
 
+    pub fn distribution_shift(
+        &mut self,
+        distribution_shift: Option<DistributionShift>,
+    ) -> &mut Search<'a> {
+        self.distribution_shift = distribution_shift;
+        self
+    }
+
+    pub fn embedder_name(&mut self, embedder_name: impl Into<String>) -> &mut Search<'a> {
+        self.embedder_name = Some(embedder_name.into());
+        self
+    }
+
     pub fn execute(&self) -> Result<SearchResult> {
+        let embedder_name;
+        let embedder_name = match &self.embedder_name {
+            Some(embedder_name) => embedder_name,
+            None => {
+                embedder_name = self.index.default_embedding_name(self.rtxn)?;
+                &embedder_name
+            }
+        };
+
         let mut ctx = SearchContext::new(self.index, self.rtxn);
 
         if let Some(searchable_attributes) = self.searchable_attributes {
@@ -202,6 +229,8 @@ impl<'a> Search<'a> {
                 self.geo_strategy,
                 self.offset,
                 self.limit,
+                self.distribution_shift,
+                embedder_name,
             )?,
             None => execute_search(
                 &mut ctx,
@@ -247,6 +276,8 @@ impl fmt::Debug for Search<'_> {
             exhaustive_number_hits,
             rtxn: _,
             index: _,
+            distribution_shift,
+            embedder_name,
         } = self;
         f.debug_struct("Search")
             .field("query", query)
@@ -260,6 +291,8 @@ impl fmt::Debug for Search<'_> {
             .field("scoring_strategy", scoring_strategy)
             .field("exhaustive_number_hits", exhaustive_number_hits)
             .field("words_limit", words_limit)
+            .field("distribution_shift", distribution_shift)
+            .field("embedder_name", embedder_name)
             .finish()
     }
 }
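Both `Search::execute` and the hybrid path above resolve the embedder name with the same deferred-binding trick: `let embedder_name;` is declared uninitialized so the match can either borrow the name already set on the builder or own the default computed from the index, without cloning in the common case. A standalone sketch of the pattern (`default_name` is a stand-in for `Index::default_embedding_name`):

```rust
// Stand-in for `Index::default_embedding_name(rtxn)?`.
fn default_name() -> String {
    "default".to_owned()
}

fn main() {
    let configured: Option<String> = None;

    // Deferred binding: `owned` is only initialized on the `None` branch, so the
    // `Some` branch borrows the existing value and both branches yield a `&String`.
    let owned;
    let name = match &configured {
        Some(name) => name,
        None => {
            owned = default_name();
            &owned
        }
    };

    assert_eq!(name, "default");
}
```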
diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs
index bc7f6fb08..405b9747d 100644
--- a/milli/src/search/new/mod.rs
+++ b/milli/src/search/new/mod.rs
@@ -266,6 +266,7 @@ fn get_ranking_rules_for_vector<'ctx>(
     limit_plus_offset: usize,
     target: &[f32],
     distribution_shift: Option<DistributionShift>,
+    embedder_name: &str,
 ) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
     // query graph search
 
@@ -292,6 +293,7 @@ fn get_ranking_rules_for_vector<'ctx>(
                 vector_candidates,
                 limit_plus_offset,
                 distribution_shift,
+                embedder_name,
             )?;
             ranking_rules.push(Box::new(vector_sort));
             vector = true;
@@ -513,6 +515,8 @@ pub fn execute_vector_search(
     geo_strategy: geo_sort::Strategy,
     from: usize,
     length: usize,
+    distribution_shift: Option<DistributionShift>,
+    embedder_name: &str,
 ) -> Result<PartialSearchResult> {
     check_sort_criteria(ctx, sort_criteria.as_ref())?;
 
@@ -524,7 +528,8 @@ pub fn execute_vector_search(
         geo_strategy,
         from + length,
         vector,
-        None,
+        distribution_shift,
+        embedder_name,
     )?;
 
     let mut placeholder_search_logger = logger::DefaultSearchLogger;
diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs
index 38fcfde48..6a37ceb7d 100644
--- a/milli/src/search/new/vector_sort.rs
+++ b/milli/src/search/new/vector_sort.rs
@@ -15,16 +15,21 @@ pub struct VectorSort<Q: RankingRuleQueryTrait> {
     cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec<f32>)>,
     limit: usize,
     distribution_shift: Option<DistributionShift>,
+    embedder_index: u8,
 }
 
 impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
     pub fn new(
-        _ctx: &SearchContext,
+        ctx: &SearchContext,
         target: Vec<f32>,
         vector_candidates: RoaringBitmap,
         limit: usize,
         distribution_shift: Option<DistributionShift>,
+        embedder_name: &str,
     ) -> Result<Self> {
+        /// FIXME: unwrap
+        let embedder_index = ctx.index.embedder_category_id.get(ctx.txn, embedder_name)?.unwrap();
+
         Ok(Self {
             query: None,
             target,
@@ -32,6 +37,7 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
             cached_sorted_docids: Default::default(),
             limit,
             distribution_shift,
+            embedder_index,
         })
     }
 
@@ -40,9 +46,10 @@ impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
         ctx: &mut SearchContext<'_>,
         vector_candidates: &RoaringBitmap,
     ) -> Result<()> {
+        let writer_index = (self.embedder_index as u16) << 8;
         let readers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
             .map_while(|k| {
-                arroy::Reader::open(ctx.txn, k.into(), ctx.index.vector_arroy)
+                arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy)
                     .map(Some)
                     .or_else(|e| match e {
                         arroy::Error::MissingMetadata => Ok(None),
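The `VectorSort` change above assumes a packed `u16` arroy index: the embedder's category id (from `index.embedder_category_id`) occupies the high byte and the per-embedder store number the low byte, so every embedder owns a disjoint range of 256 arroy trees and embedder 0 keeps the pre-existing `0..=255` range. A small sketch of that layout (the helper name is illustrative, not milli API):

```rust
// Illustrative helper: pack an embedder id and a store number into a u16,
// mirroring `(self.embedder_index as u16) << 8` and `writer_index | (k as u16)`
// in the reader loop above.
fn arroy_index(embedder_index: u8, store: u8) -> u16 {
    ((embedder_index as u16) << 8) | store as u16
}

fn main() {
    // embedder 0 keeps the old 0..=255 range, so single-embedder data stays addressable
    assert_eq!(arroy_index(0, 5), 5);
    // embedder 2, store 5 lands in a disjoint range
    assert_eq!(arroy_index(2, 5), 0x0205);

    // the reader loop scans every store of one embedder: 0x0200..=0x02FF for embedder 2
    let writer_index = (2u16) << 8;
    let ids: Vec<u16> = (0..=u8::MAX).map(|k| writer_index | k as u16).collect();
    assert_eq!((ids[0], ids[255]), (0x0200, 0x02FF));
}
```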
diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs
index a852b035b..1d06849de 100644
--- a/milli/src/update/index_documents/extract/mod.rs
+++ b/milli/src/update/index_documents/extract/mod.rs
@@ -9,10 +9,9 @@ mod extract_word_docids;
 mod extract_word_pair_proximity_docids;
 mod extract_word_position_docids;
 
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
 use std::fs::File;
 use std::io::BufReader;
-use std::sync::Arc;
 
 use crossbeam_channel::Sender;
 use log::debug;
@@ -35,9 +34,8 @@ use super::helpers::{
     MergeFn, MergeableReader,
 };
 use super::{helpers, TypedChunk};
-use crate::prompt::Prompt;
 use crate::proximity::ProximityPrecision;
-use crate::vector::Embedder;
+use crate::vector::EmbeddingConfigs;
 use crate::{FieldId, FieldsIdsMap, Result};
 
 /// Extract data for each databases from obkv documents in parallel.
@@ -59,7 +57,7 @@ pub(crate) fn data_from_obkv_documents(
     max_positions_per_attributes: Option<u32>,
     exact_attributes: HashSet<FieldId>,
     proximity_precision: ProximityPrecision,
-    embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
+    embedders: EmbeddingConfigs,
 ) -> Result<()> {
     puffin::profile_function!();
 
@@ -284,7 +282,7 @@ fn send_original_documents_data(
     indexer: GrenadParameters,
     lmdb_writer_sx: Sender<Result<TypedChunk>>,
     field_id_map: FieldsIdsMap,
-    embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
+    embedders: EmbeddingConfigs,
 ) -> Result<()> {
     let original_documents_chunk =
         original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs
index 075dcd184..efc6b22ff 100644
--- a/milli/src/update/index_documents/mod.rs
+++ b/milli/src/update/index_documents/mod.rs
@@ -9,7 +9,6 @@ use std::io::{Cursor, Read, Seek};
 use std::iter::FromIterator;
 use std::num::NonZeroU32;
 use std::result::Result as StdResult;
-use std::sync::Arc;
 
 use crossbeam_channel::{Receiver, Sender};
 use heed::types::Str;
@@ -34,12 +33,11 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
 pub use self::transform::{Transform, TransformOutput};
 use crate::documents::{obkv_to_object, DocumentsBatchReader};
 use crate::error::{Error, InternalError, UserError};
-use crate::prompt::Prompt;
 pub use crate::update::index_documents::helpers::CursorClonableMmap;
 use crate::update::{
     IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
 };
-use crate::vector::Embedder;
+use crate::vector::EmbeddingConfigs;
 use crate::{CboRoaringBitmapCodec, Index, Result};
 
 static MERGED_DATABASE_COUNT: usize = 7;
@@ -82,7 +80,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
     should_abort: FA,
     added_documents: u64,
     deleted_documents: u64,
-    embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
+    embedders: EmbeddingConfigs,
 }
 
 #[derive(Default, Debug, Clone)]
@@ -173,10 +171,7 @@ where
         Ok((self, Ok(indexed_documents)))
     }
 
-    pub fn with_embedders(
-        mut self,
-        embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
-    ) -> Self {
+    pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self {
         self.embedders = embedders;
         self
     }
diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs
index 1149dbce5..e9f345e42 100644
--- a/milli/src/update/settings.rs
+++ b/milli/src/update/settings.rs
@@ -14,12 +14,11 @@ use super::IndexerConfig;
 use crate::criterion::Criterion;
 use crate::error::UserError;
 use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS};
-use crate::prompt::Prompt;
 use crate::proximity::ProximityPrecision;
 use crate::update::index_documents::IndexDocumentsMethod;
 use crate::update::{IndexDocuments, UpdateIndexingStep};
 use crate::vector::settings::{EmbeddingSettings, PromptSettings};
-use crate::vector::{Embedder, EmbeddingConfig};
+use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
 use crate::{FieldsIdsMap, Index, OrderBy, Result};
 
 #[derive(Debug, Clone, PartialEq, Eq, Copy)]
@@ -422,7 +421,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
     fn embedders(
         &self,
         embedding_configs: Vec<(String, EmbeddingConfig)>,
-    ) -> Result<HashMap<String, (Arc<Embedder>, Arc<Prompt>)>> {
+    ) -> Result<EmbeddingConfigs> {
         let res: Result<_> = embedding_configs
             .into_iter()
             .map(|(name, EmbeddingConfig { embedder_options, prompt })| {
@@ -436,7 +435,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
                 Ok((name, (embedder, prompt)))
             })
             .collect();
-        res
+        res.map(EmbeddingConfigs::new)
     }
 
     fn update_displayed(&mut self) -> Result<bool> {
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index fa39c20a2..df5750e77 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -1,5 +1,8 @@
+use std::collections::HashMap;
+use std::sync::Arc;
+
 use self::error::{EmbedError, NewEmbedderError};
-use crate::prompt::PromptData;
+use crate::prompt::{Prompt, PromptData};
 
 pub mod error;
 pub mod hf;
@@ -82,6 +85,44 @@ pub struct EmbeddingConfig {
     // TODO: add metrics and anything needed
 }
 
+#[derive(Clone, Default)]
+pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
+
+impl EmbeddingConfigs {
+    pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
+        Self(data)
+    }
+
+    pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
+        self.0.get(name).cloned()
+    }
+
+    pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
+        self.get_default_embedder_name().and_then(|default| self.get(&default))
+    }
+
+    pub fn get_default_embedder_name(&self) -> Option<String> {
+        let mut it = self.0.keys();
+        let first_name = it.next();
+        let second_name = it.next();
+        match (first_name, second_name) {
+            (None, _) => None,
+            (Some(first), None) => Some(first.to_owned()),
+            (Some(_), Some(_)) => Some("default".to_owned()),
+        }
+    }
+}
+
+impl IntoIterator for EmbeddingConfigs {
+    type Item = (String, (Arc<Embedder>, Arc<Prompt>));
+
+    type IntoIter = std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>)>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.0.into_iter()
+    }
+}
+
 #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
 pub enum EmbedderOptions {
     HuggingFace(hf::EmbedderOptions),
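For reference, a compact sketch of the default-name rule that `EmbeddingConfigs::get_default` relies on. This is a standalone mirror, not milli's actual type: `()` replaces the `(Arc<Embedder>, Arc<Prompt>)` values so the example stays self-contained:

```rust
use std::collections::HashMap;

// Standalone mirror of `EmbeddingConfigs` with `()` values, exercising only the
// default-name logic.
struct Configs(HashMap<String, ()>);

impl Configs {
    fn get(&self, name: &str) -> Option<()> {
        self.0.get(name).cloned()
    }

    fn get_default(&self) -> Option<()> {
        self.get_default_name().and_then(|default| self.get(&default))
    }

    fn get_default_name(&self) -> Option<String> {
        let mut it = self.0.keys();
        match (it.next(), it.next()) {
            // nothing configured: no default at all
            (None, _) => None,
            // exactly one embedder: it is the default, whatever its name
            (Some(first), None) => Some(first.to_owned()),
            // several embedders: only an entry literally named "default" qualifies
            (Some(_), Some(_)) => Some("default".to_owned()),
        }
    }
}

fn main() {
    let single = Configs(HashMap::from([("custom".to_owned(), ())]));
    assert_eq!(single.get_default_name().as_deref(), Some("custom"));
    assert!(single.get_default().is_some());

    let several = Configs(HashMap::from([("a".to_owned(), ()), ("b".to_owned(), ())]));
    // with several embedders and none of them named "default", get_default finds nothing
    assert_eq!(several.get_default_name().as_deref(), Some("default"));
    assert!(several.get_default().is_none());
}
```

This mirrors how the search route consumes the type: `embedders.get(name)` for an explicitly requested embedder, `embedders.get_default()` otherwise, with a missing result reported as `InvalidEmbedder`.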