813 lines
27 KiB
Rust
Raw Normal View History

use std::collections::HashMap;
use std::sync::Arc;
use arroy::distances::{Angular, BinaryQuantizedAngular};
use arroy::ItemId;
2024-03-27 11:50:22 +01:00
use deserr::{DeserializeError, Deserr};
use heed::{RoTxn, RwTxn, Unspecified};
2024-03-25 10:05:38 +01:00
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
2024-03-25 10:05:38 +01:00
use serde::{Deserialize, Serialize};
use self::error::{EmbedError, NewEmbedderError};
use crate::prompt::{Prompt, PromptData};
use crate::ThreadPoolNoAbort;
pub mod error;
pub mod hf;
2024-07-16 13:37:26 +02:00
pub mod json_template;
pub mod manual;
pub mod openai;
2024-05-14 11:22:16 +02:00
pub mod parsed_vectors;
pub mod settings;
pub mod ollama;
2024-03-14 14:44:43 +01:00
pub mod rest;
pub use self::error::Error;
/// A single embedding, stored as a vector of `f32` components.
pub type Embedding = Vec<f32>;
/// NOTE(review): presumably the number of embedding requests that may run in parallel;
/// the constant is not used in this file — confirm against the embedders that read it.
pub const REQUEST_PARALLELISM: usize = 40;
pub struct ArroyWrapper {
quantized: bool,
2024-09-24 10:36:28 +02:00
embedder_index: u8,
database: arroy::Database<Unspecified>,
}
impl ArroyWrapper {
2024-09-24 10:36:28 +02:00
pub fn new(
database: arroy::Database<Unspecified>,
embedder_index: u8,
quantized: bool,
) -> Self {
Self { database, embedder_index, quantized }
}
pub fn embedder_index(&self) -> u8 {
2024-09-24 10:36:28 +02:00
self.embedder_index
}
2024-09-23 18:56:15 +02:00
fn readers<'a, D: arroy::Distance>(
&'a self,
rtxn: &'a RoTxn<'a>,
db: arroy::Database<D>,
) -> impl Iterator<Item = Result<arroy::Reader<D>, arroy::Error>> + 'a {
2024-09-24 10:36:28 +02:00
arroy_db_range_for_embedder(self.embedder_index).map_while(move |index| {
2024-09-23 18:56:15 +02:00
match arroy::Reader::open(rtxn, index, db) {
Ok(reader) => Some(Ok(reader)),
Err(arroy::Error::MissingMetadata(_)) => None,
Err(e) => Some(Err(e)),
}
})
}
pub fn dimensions(&self, rtxn: &RoTxn) -> Result<usize, arroy::Error> {
2024-09-24 10:36:28 +02:00
let first_id = arroy_db_range_for_embedder(self.embedder_index).next().unwrap();
if self.quantized {
2024-09-19 17:42:52 +02:00
Ok(arroy::Reader::open(rtxn, first_id, self.quantized_db())?.dimensions())
} else {
2024-09-19 17:42:52 +02:00
Ok(arroy::Reader::open(rtxn, first_id, self.angular_db())?.dimensions())
}
}
2024-09-19 17:42:52 +02:00
pub fn quantize(&mut self, wtxn: &mut RwTxn, dimension: usize) -> Result<(), arroy::Error> {
if !self.quantized {
2024-09-24 10:36:28 +02:00
for index in arroy_db_range_for_embedder(self.embedder_index) {
2024-09-19 17:42:52 +02:00
let writer = arroy::Writer::new(self.angular_db(), index, dimension);
writer.prepare_changing_distance::<BinaryQuantizedAngular>(wtxn)?;
}
self.quantized = true;
}
Ok(())
}
2024-09-19 17:42:52 +02:00
// TODO: We can stop early when we find an empty DB
pub fn need_build(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> {
2024-09-24 10:36:28 +02:00
for index in arroy_db_range_for_embedder(self.embedder_index) {
2024-09-19 17:42:52 +02:00
let need_build = if self.quantized {
arroy::Writer::new(self.quantized_db(), index, dimension).need_build(rtxn)
} else {
arroy::Writer::new(self.angular_db(), index, dimension).need_build(rtxn)
};
if need_build? {
return Ok(true);
}
}
2024-09-19 17:42:52 +02:00
Ok(false)
}
2024-09-19 17:42:52 +02:00
/// TODO: We should early exit when it doesn't need to be built
pub fn build<R: rand::Rng + rand::SeedableRng>(
&self,
wtxn: &mut RwTxn,
rng: &mut R,
dimension: usize,
) -> Result<(), arroy::Error> {
2024-09-24 10:36:28 +02:00
for index in arroy_db_range_for_embedder(self.embedder_index) {
2024-09-19 17:42:52 +02:00
if self.quantized {
arroy::Writer::new(self.quantized_db(), index, dimension).build(wtxn, rng, None)?
} else {
arroy::Writer::new(self.angular_db(), index, dimension).build(wtxn, rng, None)?
}
}
2024-09-19 17:42:52 +02:00
Ok(())
}
2024-09-23 18:56:15 +02:00
/// Overwrite all the embeddings associated to the index and item id.
2024-09-23 15:15:26 +02:00
pub fn add_items(
&self,
wtxn: &mut RwTxn,
item_id: arroy::ItemId,
embeddings: &Embeddings<f32>,
) -> Result<(), arroy::Error> {
let dimension = embeddings.dimension();
2024-09-24 10:36:28 +02:00
for (index, vector) in
arroy_db_range_for_embedder(self.embedder_index).zip(embeddings.iter())
{
2024-09-23 15:15:26 +02:00
if self.quantized {
arroy::Writer::new(self.quantized_db(), index, dimension)
.add_item(wtxn, item_id, vector)?
} else {
arroy::Writer::new(self.angular_db(), index, dimension)
.add_item(wtxn, item_id, vector)?
}
}
Ok(())
}
2024-09-23 18:56:15 +02:00
/// Adds one vector for `item_id` in the first arroy index of this embedder
/// that does not already contain the item.
pub fn add_item(
    &self,
    wtxn: &mut RwTxn,
    item_id: arroy::ItemId,
    vector: &[f32],
) -> Result<(), arroy::Error> {
    match self.quantized {
        true => self._add_item(wtxn, self.quantized_db(), item_id, vector),
        false => self._add_item(wtxn, self.angular_db(), item_id, vector),
    }
}
fn _add_item<D: arroy::Distance>(
&self,
wtxn: &mut RwTxn,
db: arroy::Database<D>,
item_id: arroy::ItemId,
vector: &[f32],
) -> Result<(), arroy::Error> {
2024-09-23 15:15:26 +02:00
let dimension = vector.len();
2024-09-24 10:36:28 +02:00
for index in arroy_db_range_for_embedder(self.embedder_index) {
2024-09-23 18:56:15 +02:00
let writer = arroy::Writer::new(db, index, dimension);
if !writer.contains_item(wtxn, item_id)? {
writer.add_item(wtxn, item_id, vector)?;
break;
2024-09-23 15:15:26 +02:00
}
}
2024-09-23 15:15:26 +02:00
Ok(())
}
2024-09-23 18:56:15 +02:00
/// Delete an item from the index. It **does not** take care of fixing the hole
/// made after deleting the item.
2024-09-23 15:15:26 +02:00
pub fn del_item_raw(
&self,
wtxn: &mut RwTxn,
dimension: usize,
item_id: arroy::ItemId,
) -> Result<bool, arroy::Error> {
2024-09-24 10:36:28 +02:00
for index in arroy_db_range_for_embedder(self.embedder_index) {
2024-09-23 15:15:26 +02:00
if self.quantized {
let writer = arroy::Writer::new(self.quantized_db(), index, dimension);
if writer.del_item(wtxn, item_id)? {
return Ok(true);
}
} else {
let writer = arroy::Writer::new(self.angular_db(), index, dimension);
if writer.del_item(wtxn, item_id)? {
return Ok(true);
}
}
}
2024-09-23 15:15:26 +02:00
Ok(false)
}
2024-09-23 18:56:15 +02:00
/// Delete one item.
2024-09-23 15:15:26 +02:00
pub fn del_item(
&self,
wtxn: &mut RwTxn,
2024-09-23 18:56:15 +02:00
item_id: arroy::ItemId,
vector: &[f32],
) -> Result<bool, arroy::Error> {
if self.quantized {
self._del_item(wtxn, self.quantized_db(), item_id, vector)
} else {
self._del_item(wtxn, self.angular_db(), item_id, vector)
}
}
fn _del_item<D: arroy::Distance>(
&self,
wtxn: &mut RwTxn,
db: arroy::Database<D>,
item_id: arroy::ItemId,
2024-09-23 15:15:26 +02:00
vector: &[f32],
) -> Result<bool, arroy::Error> {
let dimension = vector.len();
let mut deleted_index = None;
2024-09-24 10:36:28 +02:00
for index in arroy_db_range_for_embedder(self.embedder_index) {
2024-09-23 18:56:15 +02:00
let writer = arroy::Writer::new(db, index, dimension);
let Some(candidate) = writer.item_vector(wtxn, item_id)? else {
// uses invariant: vectors are packed in the first writers.
break;
};
if candidate == vector {
writer.del_item(wtxn, item_id)?;
deleted_index = Some(index);
2024-09-23 15:15:26 +02:00
}
}
// 🥲 enforce invariant: vectors are packed in the first writers.
if let Some(deleted_index) = deleted_index {
let mut last_index_with_a_vector = None;
2024-09-24 10:36:28 +02:00
for index in
arroy_db_range_for_embedder(self.embedder_index).skip(deleted_index as usize)
{
2024-09-23 18:56:15 +02:00
let writer = arroy::Writer::new(db, index, dimension);
let Some(candidate) = writer.item_vector(wtxn, item_id)? else {
break;
};
last_index_with_a_vector = Some((index, candidate));
2024-09-23 15:15:26 +02:00
}
if let Some((last_index, vector)) = last_index_with_a_vector {
2024-09-23 18:56:15 +02:00
// unwrap: computed the index from the list of writers
let writer = arroy::Writer::new(db, last_index, dimension);
writer.del_item(wtxn, item_id)?;
let writer = arroy::Writer::new(db, deleted_index, dimension);
writer.add_item(wtxn, item_id, &vector)?;
2024-09-23 15:15:26 +02:00
}
}
Ok(deleted_index.is_some())
}
pub fn clear(&self, wtxn: &mut RwTxn, dimension: usize) -> Result<(), arroy::Error> {
2024-09-24 10:36:28 +02:00
for index in arroy_db_range_for_embedder(self.embedder_index) {
2024-09-23 15:15:26 +02:00
if self.quantized {
arroy::Writer::new(self.quantized_db(), index, dimension).clear(wtxn)?;
} else {
arroy::Writer::new(self.angular_db(), index, dimension).clear(wtxn)?;
}
}
2024-09-23 15:15:26 +02:00
Ok(())
}
pub fn is_empty(&self, rtxn: &RoTxn, dimension: usize) -> Result<bool, arroy::Error> {
2024-09-24 10:36:28 +02:00
for index in arroy_db_range_for_embedder(self.embedder_index) {
2024-09-23 15:15:26 +02:00
let empty = if self.quantized {
arroy::Writer::new(self.quantized_db(), index, dimension).is_empty(rtxn)?
} else {
arroy::Writer::new(self.angular_db(), index, dimension).is_empty(rtxn)?
};
if !empty {
return Ok(false);
}
}
2024-09-23 15:15:26 +02:00
Ok(true)
}
pub fn contains_item(
&self,
rtxn: &RoTxn,
dimension: usize,
item: arroy::ItemId,
) -> Result<bool, arroy::Error> {
2024-09-24 10:36:28 +02:00
for index in arroy_db_range_for_embedder(self.embedder_index) {
2024-09-23 15:15:26 +02:00
let contains = if self.quantized {
arroy::Writer::new(self.quantized_db(), index, dimension)
.contains_item(rtxn, item)?
} else {
arroy::Writer::new(self.angular_db(), index, dimension).contains_item(rtxn, item)?
};
if contains {
return Ok(contains);
}
}
2024-09-23 15:15:26 +02:00
Ok(false)
}
/// Searches the nearest neighbours of `item` in the arroy indexes of this embedder.
///
/// Returns `(item id, distance)` pairs sorted by ascending distance. `limit` is passed
/// to each index separately, so the result can hold more than `limit` entries.
pub fn nns_by_item(
    &self,
    rtxn: &RoTxn,
    item: ItemId,
    limit: usize,
    filter: Option<&RoaringBitmap>,
) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
    match self.quantized {
        true => self._nns_by_item(rtxn, self.quantized_db(), item, limit, filter),
        false => self._nns_by_item(rtxn, self.angular_db(), item, limit, filter),
    }
}
fn _nns_by_item<D: arroy::Distance>(
&self,
rtxn: &RoTxn,
db: arroy::Database<D>,
item: ItemId,
limit: usize,
filter: Option<&RoaringBitmap>,
2024-09-23 15:15:26 +02:00
) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
let mut results = Vec::new();
2024-09-23 18:56:15 +02:00
for reader in self.readers(rtxn, db) {
let ret = reader?.nns_by_item(rtxn, item, limit, None, None, filter)?;
2024-09-23 15:15:26 +02:00
if let Some(mut ret) = ret {
results.append(&mut ret);
} else {
break;
}
}
2024-09-23 15:15:26 +02:00
results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance));
Ok(results)
}
pub fn nns_by_vector(
&self,
2024-09-23 18:56:15 +02:00
rtxn: &RoTxn,
vector: &[f32],
limit: usize,
filter: Option<&RoaringBitmap>,
) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
if self.quantized {
self._nns_by_vector(rtxn, self.quantized_db(), vector, limit, filter)
} else {
self._nns_by_vector(rtxn, self.angular_db(), vector, limit, filter)
}
}
fn _nns_by_vector<D: arroy::Distance>(
&self,
rtxn: &RoTxn,
db: arroy::Database<D>,
vector: &[f32],
limit: usize,
filter: Option<&RoaringBitmap>,
) -> Result<Vec<(ItemId, f32)>, arroy::Error> {
2024-09-23 15:15:26 +02:00
let mut results = Vec::new();
2024-09-23 18:56:15 +02:00
for reader in self.readers(rtxn, db) {
let mut ret = reader?.nns_by_vector(rtxn, vector, limit, None, None, filter)?;
2024-09-23 15:15:26 +02:00
results.append(&mut ret);
}
2024-09-23 15:15:26 +02:00
results.sort_unstable_by_key(|(_, distance)| OrderedFloat(*distance));
Ok(results)
}
2024-09-23 18:56:15 +02:00
pub fn item_vectors(&self, rtxn: &RoTxn, item_id: u32) -> Result<Vec<Vec<f32>>, arroy::Error> {
let mut vectors = Vec::new();
if self.quantized {
for reader in self.readers(rtxn, self.quantized_db()) {
if let Some(vec) = reader?.item_vector(rtxn, item_id)? {
vectors.push(vec);
} else {
break;
}
}
} else {
for reader in self.readers(rtxn, self.angular_db()) {
if let Some(vec) = reader?.item_vector(rtxn, item_id)? {
vectors.push(vec);
} else {
break;
}
2024-09-23 15:15:26 +02:00
}
}
2024-09-23 18:56:15 +02:00
Ok(vectors)
}
/// Views the untyped database with the `Angular` distance.
fn angular_db(&self) -> arroy::Database<Angular> {
    let typed: arroy::Database<Angular> = self.database.remap_data_type();
    typed
}
/// Views the untyped database with the `BinaryQuantizedAngular` distance.
fn quantized_db(&self) -> arroy::Database<BinaryQuantizedAngular> {
    let typed: arroy::Database<BinaryQuantizedAngular> = self.database.remap_data_type();
    typed
}
}
2024-03-12 15:00:26 +01:00
/// One or multiple embeddings stored consecutively in a flat vector.
///
/// Invariant: `data.len()` is a multiple of `dimension`. When `dimension` is 0,
/// `data` is always empty and the container holds no embedding.
pub struct Embeddings<F> {
    data: Vec<F>,
    dimension: usize,
}

impl<F> Embeddings<F> {
    /// Declares an empty vector of embeddings of the specified dimensions.
    pub fn new(dimension: usize) -> Self {
        Self { data: Vec::new(), dimension }
    }

    /// Declares a vector of embeddings containing a single element.
    ///
    /// The dimension is inferred from the length of the passed embedding.
    pub fn from_single_embedding(embedding: Vec<F>) -> Self {
        Self { dimension: embedding.len(), data: embedding }
    }

    /// Declares a vector of embeddings from its components.
    ///
    /// `data.len()` must be a multiple of `dimension`, otherwise `data` is returned as an error.
    /// With a `dimension` of 0, only an empty `data` is accepted.
    pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
        let mut this = Self::new(dimension);
        this.append(data)?;
        Ok(this)
    }

    /// Returns the number of embeddings in this vector of embeddings.
    ///
    /// Returns 0 when the dimension is 0.
    pub fn embedding_count(&self) -> usize {
        // `checked_div` avoids the division-by-zero panic for a zero dimension.
        self.data.len().checked_div(self.dimension).unwrap_or(0)
    }

    /// Dimension of a single embedding.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Deconstructs self into the inner flat vector.
    pub fn into_inner(self) -> Vec<F> {
        self.data
    }

    /// A reference to the inner flat vector.
    pub fn as_inner(&self) -> &[F] {
        &self.data
    }

    /// Iterates over the embeddings contained in the flat vector.
    ///
    /// Yields nothing when the dimension is 0.
    pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
        // `chunks_exact` panics on a chunk size of 0. With a zero dimension `data` is
        // empty, so any non-zero chunk size yields no chunk at all.
        self.data.chunks_exact(self.dimension.max(1))
    }

    /// Push an embedding at the end of the embeddings.
    ///
    /// If `embedding.len() != self.dimension`, then the push operation fails
    /// and the embedding is handed back as the error.
    pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
        if embedding.len() != self.dimension {
            return Err(embedding);
        }
        self.data.append(&mut embedding);
        Ok(())
    }

    /// Append a flat vector of embeddings at the end of the embeddings.
    ///
    /// If `embeddings.len()` is not a multiple of `self.dimension`, then the append operation
    /// fails and the vector is handed back as the error. With a dimension of 0, only an
    /// empty vector can be appended.
    pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
        let fits = match self.dimension {
            // `len % 0` would panic: nothing but an empty vector fits a zero dimension.
            0 => embeddings.is_empty(),
            dimension => embeddings.len() % dimension == 0,
        };
        if !fits {
            return Err(embeddings);
        }
        self.data.append(&mut embeddings);
        Ok(())
    }
}
2024-03-12 15:00:26 +01:00
/// An embedder can be used to transform text into embeddings.
#[derive(Debug)]
pub enum Embedder {
2024-03-12 15:00:26 +01:00
/// An embedder based on running local models, fetched from the Hugging Face Hub.
HuggingFace(hf::Embedder),
2024-03-12 15:00:26 +01:00
/// An embedder based on making embedding queries against the OpenAI API.
OpenAi(openai::Embedder),
2024-03-12 15:00:26 +01:00
/// An embedder based on the user providing the embeddings in the documents and queries.
UserProvided(manual::Embedder),
2024-03-25 10:05:38 +01:00
/// An embedder based on making embedding queries against an <https://ollama.com> embedding server.
Ollama(ollama::Embedder),
2024-03-25 10:05:38 +01:00
/// An embedder based on making embedding queries against a generic JSON/REST embedding server.
Rest(rest::Embedder),
}
2024-03-12 15:00:26 +01:00
/// Configuration for an embedder.
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
pub struct EmbeddingConfig {
2024-03-12 15:00:26 +01:00
/// Options of the embedder, specific to each kind of embedder
pub embedder_options: EmbedderOptions,
2024-03-12 15:00:26 +01:00
/// Document template
pub prompt: PromptData,
/// If this embedder is binary quantized
pub quantized: Option<bool>,
// TODO: add metrics and anything needed
}
impl EmbeddingConfig {
pub fn quantized(&self) -> bool {
self.quantized.unwrap_or_default()
}
}
2024-03-12 15:00:26 +01:00
/// Map of embedder configurations.
///
/// Each name is mapped to an embedder, its prompt and a boolean flag.
/// NOTE(review): the flag is presumably the `quantized` setting of the embedder — confirm
/// against the code that builds the map.
#[derive(Clone, Default)]
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>);

impl EmbeddingConfigs {
    /// Creates the map from its internal components.
    pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)>) -> Self {
        EmbeddingConfigs(data)
    }

    /// Gets the embedder, prompt and flag registered under `name`, if any.
    ///
    /// The `Arc`s are cloned, not the values they point to.
    pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>, bool)> {
        let (embedder, prompt, flag) = self.0.get(name)?;
        Some((Arc::clone(embedder), Arc::clone(prompt), *flag))
    }

    /// Borrows the inner map.
    pub fn inner_as_ref(&self) -> &HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the inner map.
    pub fn into_inner(self) -> HashMap<String, (Arc<Embedder>, Arc<Prompt>, bool)> {
        let Self(inner) = self;
        inner
    }
}
impl IntoIterator for EmbeddingConfigs {
type Item = (String, (Arc<Embedder>, Arc<Prompt>, bool));
type IntoIter =
std::collections::hash_map::IntoIter<String, (Arc<Embedder>, Arc<Prompt>, bool)>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
2024-03-12 15:00:26 +01:00
/// Options of an embedder, specific to each kind of embedder.
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
Ollama(ollama::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
2024-03-25 10:05:38 +01:00
Rest(rest::EmbedderOptions),
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self::HuggingFace(Default::default())
}
}
impl Embedder {
2024-03-12 15:00:26 +01:00
/// Spawns a new embedder built from its options.
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
2024-07-16 15:04:40 +02:00
EmbedderOptions::Rest(options) => {
Self::Rest(rest::Embedder::new(options, rest::ConfigurationSource::User)?)
}
})
}
2024-03-12 15:00:26 +01:00
/// Embed one or multiple texts.
///
/// Each text can be embedded as one or multiple embeddings.
pub fn embed(
&self,
texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts),
2024-03-14 11:14:31 +01:00
Embedder::OpenAi(embedder) => embedder.embed(texts),
Embedder::Ollama(embedder) => embedder.embed(texts),
Embedder::UserProvided(embedder) => embedder.embed(texts),
2024-03-25 10:05:38 +01:00
Embedder::Rest(embedder) => embedder.embed(texts),
}
}
2024-03-28 11:49:23 +01:00
pub fn embed_one(&self, text: String) -> std::result::Result<Embedding, EmbedError> {
let mut embeddings = self.embed(vec![text])?;
let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?;
Ok(if embeddings.iter().nth(1).is_some() {
tracing::warn!("Ignoring embeddings past the first one in long search query");
embeddings.iter().next().unwrap().to_vec()
} else {
embeddings.into_inner()
})
}
2024-03-12 15:00:26 +01:00
/// Embed multiple chunks of texts.
///
/// Each chunk is composed of one or multiple texts.
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &ThreadPoolNoAbort,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
2024-03-25 10:05:38 +01:00
Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads),
}
}
2024-03-12 15:00:26 +01:00
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
pub fn chunk_count_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
2024-03-25 10:05:38 +01:00
Embedder::Rest(embedder) => embedder.chunk_count_hint(),
}
}
2024-03-12 15:00:26 +01:00
/// Indicates the preferred number of texts in a single chunk passed to [`Self::embed`]
pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
2024-03-25 10:05:38 +01:00
Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
}
}
2024-03-12 15:00:26 +01:00
/// Indicates the dimensions of a single embedding produced by the embedder.
pub fn dimensions(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::Ollama(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
2024-03-25 10:05:38 +01:00
Embedder::Rest(embedder) => embedder.dimensions(),
}
}
2024-03-12 15:00:26 +01:00
/// An optional distribution used to apply an affine transformation to the similarity score of a document.
pub fn distribution(&self) -> Option<DistributionShift> {
match self {
Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(),
2024-03-27 11:50:22 +01:00
Embedder::UserProvided(embedder) => embedder.distribution(),
2024-03-25 10:05:38 +01:00
Embedder::Rest(embedder) => embedder.distribution(),
}
}
2024-09-02 12:58:09 +02:00
pub fn uses_document_template(&self) -> bool {
match self {
Embedder::HuggingFace(_)
| Embedder::OpenAi(_)
| Embedder::Ollama(_)
| Embedder::Rest(_) => true,
Embedder::UserProvided(_) => false,
}
}
}
2024-03-12 15:00:26 +01:00
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
///
/// The intended use is to make the similarity score more comparable to the regular ranking score.
/// This allows to correct effects where results are too "packed" around a certain value.
2024-03-25 10:05:38 +01:00
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(from = "DistributionShiftSerializable")]
#[serde(into = "DistributionShiftSerializable")]
pub struct DistributionShift {
2024-03-12 15:00:26 +01:00
/// Value where the results are "packed".
///
/// Similarity scores are translated so that they are packed around 0.5 instead
2024-03-25 10:05:38 +01:00
pub current_mean: OrderedFloat<f32>,
2024-03-12 15:00:26 +01:00
/// standard deviation of a similarity score.
///
/// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
2024-03-25 10:05:38 +01:00
pub current_sigma: OrderedFloat<f32>,
}
2024-03-27 11:50:33 +01:00
impl<E> Deserr<E> for DistributionShift
where
E: DeserializeError,
{
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef<'_>,
2024-03-27 11:50:33 +01:00
) -> Result<Self, E> {
let value = DistributionShiftSerializable::deserialize_from_value(value, location)?;
if value.mean < 0. || value.mean > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution mean must be in the range [0, 1], got {}",
value.mean
),
},
location,
)));
}
if value.sigma <= 0. || value.sigma > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution sigma must be in the range ]0, 1], got {}",
value.sigma
),
},
location,
)));
}
Ok(value.into())
}
}
#[derive(Serialize, Deserialize, Deserr)]
#[serde(deny_unknown_fields)]
#[deserr(deny_unknown_fields)]
2024-03-25 10:05:38 +01:00
struct DistributionShiftSerializable {
2024-03-27 11:50:33 +01:00
mean: f32,
sigma: f32,
2024-03-25 10:05:38 +01:00
}
impl From<DistributionShift> for DistributionShiftSerializable {
fn from(
DistributionShift {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}: DistributionShift,
) -> Self {
2024-03-27 11:50:33 +01:00
Self { mean: current_mean, sigma: current_sigma }
2024-03-25 10:05:38 +01:00
}
}
impl From<DistributionShiftSerializable> for DistributionShift {
2024-03-27 11:50:33 +01:00
fn from(DistributionShiftSerializable { mean, sigma }: DistributionShiftSerializable) -> Self {
Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }
2024-03-25 10:05:38 +01:00
}
}
impl DistributionShift {
/// `None` if sigma <= 0.
pub fn new(mean: f32, sigma: f32) -> Option<Self> {
if sigma <= 0.0 {
None
} else {
2024-03-25 10:05:38 +01:00
Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
}
}
pub fn shift(&self, score: f32) -> f32 {
2024-03-25 10:05:38 +01:00
let current_mean = self.current_mean.0;
let current_sigma = self.current_sigma.0;
// <https://math.stackexchange.com/a/2894689>
// We're somewhat abusively mapping the distribution of distances to a gaussian.
// The parameters we're given is the mean and sigma of the native result distribution.
// We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4.
let target_mean = 0.5;
let target_sigma = 0.4;
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
2024-03-25 10:05:38 +01:00
let factor = target_sigma / current_sigma;
// a*mu1 + b = mu2 => b = mu2 - a*mu1
2024-03-25 10:05:38 +01:00
let offset = target_mean - (factor * current_mean);
let mut score = factor * score + offset;
// clamp the final score in the ]0, 1] interval.
if score <= 0.0 {
score = f32::EPSILON;
}
if score > 1.0 {
score = 1.0;
}
score
}
}
2024-02-26 10:41:47 +01:00
2024-03-12 15:00:26 +01:00
/// Whether CUDA is supported in this version of Meilisearch.
///
/// This reflects the `cuda` cargo feature at compile time.
pub const fn is_cuda_enabled() -> bool {
    const CUDA_ENABLED: bool = cfg!(feature = "cuda");
    CUDA_ENABLED
}
/// Returns the 256 arroy index identifiers reserved for the embedder `embedder_id`.
///
/// The embedder id forms the high byte of each identifier and the low byte
/// goes from 0 to 255, in increasing order.
pub fn arroy_db_range_for_embedder(embedder_id: u8) -> impl Iterator<Item = u16> {
    let first = u16::from(embedder_id) << 8;
    let last = first | u16::from(u8::MAX);
    first..=last
}