From 11759c4be4be37a5bf41d8ff62059b0639949cc3 Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Tue, 18 Feb 2025 14:16:41 +0100
Subject: [PATCH] Support pooling

---
 crates/milli/src/vector/error.rs |  44 ++++++++++
 crates/milli/src/vector/hf.rs    | 136 +++++++++++++++++++++++++++----
 2 files changed, 166 insertions(+), 14 deletions(-)

diff --git a/crates/milli/src/vector/error.rs b/crates/milli/src/vector/error.rs
index d1b2516f5..650249bff 100644
--- a/crates/milli/src/vector/error.rs
+++ b/crates/milli/src/vector/error.rs
@@ -262,6 +262,31 @@ impl NewEmbedderError {
         }
     }
 
+    pub fn open_pooling_config(
+        pooling_config_filename: PathBuf,
+        inner: std::io::Error,
+    ) -> NewEmbedderError {
+        let open_config = OpenPoolingConfig { filename: pooling_config_filename, inner };
+
+        Self {
+            kind: NewEmbedderErrorKind::OpenPoolingConfig(open_config),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub fn deserialize_pooling_config(
+        model_name: String,
+        pooling_config_filename: PathBuf,
+        inner: serde_json::Error,
+    ) -> NewEmbedderError {
+        let deserialize_pooling_config =
+            DeserializePoolingConfig { model_name, filename: pooling_config_filename, inner };
+        Self {
+            kind: NewEmbedderErrorKind::DeserializePoolingConfig(deserialize_pooling_config),
+            fault: FaultSource::Runtime,
+        }
+    }
+
     pub fn open_tokenizer(
         tokenizer_filename: PathBuf,
         inner: Box<dyn std::error::Error + Send + Sync>,
@@ -319,6 +344,13 @@ pub struct OpenConfig {
     pub filename: PathBuf,
     pub inner: std::io::Error,
 }
 
+#[derive(Debug, thiserror::Error)]
+#[error("could not open pooling config at {filename}: {inner}")]
+pub struct OpenPoolingConfig {
+    pub filename: PathBuf,
+    pub inner: std::io::Error,
+}
+
 #[derive(Debug, thiserror::Error)]
 #[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
 pub struct DeserializeConfig {
@@ -327,6 +359,14 @@ pub struct DeserializeConfig {
     pub model_name: String,
     pub filename: PathBuf,
     pub inner: serde_json::Error,
 }
 
+#[derive(Debug, thiserror::Error)]
+#[error("for model '{model_name}', could not deserialize file at `{filename}` as a pooling config: {inner}")]
+pub struct DeserializePoolingConfig {
+    pub model_name: String,
+    pub filename: PathBuf,
+    pub inner: serde_json::Error,
+}
+
 #[derive(Debug, thiserror::Error)]
 #[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
 if architectures.is_empty() {
@@ -354,8 +394,12 @@ pub enum NewEmbedderErrorKind {
     #[error(transparent)]
     OpenConfig(OpenConfig),
     #[error(transparent)]
+    OpenPoolingConfig(OpenPoolingConfig),
+    #[error(transparent)]
     DeserializeConfig(DeserializeConfig),
     #[error(transparent)]
+    DeserializePoolingConfig(DeserializePoolingConfig),
+    #[error(transparent)]
     UnsupportedModel(UnsupportedModel),
     #[error(transparent)]
     OpenTokenizer(OpenTokenizer),
diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs
index 447a88f5d..9ec34daef 100644
--- a/crates/milli/src/vector/hf.rs
+++ b/crates/milli/src/vector/hf.rs
@@ -58,6 +58,7 @@ pub struct Embedder {
     tokenizer: Tokenizer,
     options: EmbedderOptions,
     dimensions: usize,
+    pooling: Pooling,
 }
 
 impl std::fmt::Debug for Embedder {
@@ -66,10 +67,53 @@ impl std::fmt::Debug for Embedder {
             .field("model", &self.options.model)
             .field("tokenizer", &self.tokenizer)
             .field("options", &self.options)
+            .field("pooling", &self.pooling)
             .finish()
     }
 }
 
+#[derive(Clone, Copy, serde::Deserialize)]
+struct PoolingConfig {
+    #[serde(default)]
+    pub pooling_mode_cls_token: bool,
+    #[serde(default)]
+    pub pooling_mode_mean_tokens: bool,
+    #[serde(default)]
+    pub pooling_mode_max_tokens: bool,
+    #[serde(default)]
+    pub pooling_mode_mean_sqrt_len_tokens: bool,
+    #[serde(default)]
+    pub pooling_mode_lasttoken: bool,
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+pub enum Pooling {
+    #[default]
+    Mean,
+    Cls,
+    Max,
+    MeanSqrtLen,
+    LastToken,
+}
+
+impl From<PoolingConfig> for Pooling {
+    fn from(value: PoolingConfig) -> Self {
+        if value.pooling_mode_cls_token {
+            Self::Cls
+        } else if value.pooling_mode_mean_tokens {
+            Self::Mean
+        } else if value.pooling_mode_lasttoken {
+            Self::LastToken
+        } else if value.pooling_mode_mean_sqrt_len_tokens {
+            Self::MeanSqrtLen
+        } else if value.pooling_mode_max_tokens {
+            Self::Max
+        } else {
+            Self::default()
+        }
+    }
+}
+
 impl Embedder {
     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
         let device = match candle_core::Device::cuda_if_available(0) {
@@ -83,7 +127,7 @@
             Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
             None => Repo::model(options.model.clone()),
         };
-        let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
+        let (config_filename, tokenizer_filename, weights_filename, weight_source, pooling) = {
            let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
            let api = api.repo(repo);
            let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
@@ -97,7 +141,36 @@
                 })
                 .map_err(NewEmbedderError::api_get)?
             };
-            (config, tokenizer, weights, source)
+            let pooling = match api.get("1_Pooling/config.json") {
+                Ok(pooling) => Some(pooling),
+                Err(hf_hub::api::sync::ApiError::RequestError(error))
+                    if matches!(*error, ureq::Error::Status(404, _,)) =>
+                {
+                    // ignore the error if the file simply doesn't exist
+                    None
+                }
+                Err(error) => return Err(NewEmbedderError::api_get(error)),
+            };
+            let pooling: Pooling = match pooling {
+                Some(pooling_filename) => {
+                    let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
+                        NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
+                    })?;
+
+                    let pooling: PoolingConfig =
+                        serde_json::from_str(&pooling).map_err(|inner| {
+                            NewEmbedderError::deserialize_pooling_config(
+                                options.model.clone(),
+                                pooling_filename,
+                                inner,
+                            )
+                        })?;
+                    pooling.into()
+                }
+                None => Pooling::default(),
+            };
+
+            (config, tokenizer, weights, source, pooling)
         };
 
         let config = std::fs::read_to_string(&config_filename)
@@ -122,6 +195,8 @@
             },
         };
 
+        tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config");
+
         let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
 
         if let Some(pp) = tokenizer.get_padding_mut() {
@@ -134,7 +209,7 @@
             tokenizer.with_padding(Some(pp));
         }
 
-        let mut this = Self { model, tokenizer, options, dimensions: 0 };
+        let mut this = Self { model, tokenizer, options, dimensions: 0, pooling };
 
         let embeddings = this
             .embed(vec!["test".into()])
@@ -168,17 +243,53 @@
             .forward(&token_ids, &token_type_ids, None)
             .map_err(EmbedError::model_forward)?;
 
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) =
-            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
-
-        let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
-            .map_err(EmbedError::tensor_shape)?;
+        let embeddings = Self::pooling(embeddings, self.pooling)?;
 
         let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
         Ok(embeddings)
     }
 
+    fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
+        match pooling {
+            Pooling::Mean => Self::mean_pooling(embeddings),
+            Pooling::Cls => Self::cls_pooling(embeddings),
+            Pooling::Max => Self::max_pooling(embeddings),
+            Pooling::MeanSqrtLen => Self::mean_sqrt_pooling(embeddings),
+            Pooling::LastToken => Self::last_token_pooling(embeddings),
+        }
+    }
+
+    fn cls_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        embeddings.get_on_dim(1, 0).map_err(EmbedError::tensor_value)
+    }
+
+    fn mean_sqrt_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        let (_n_sentence, n_tokens, _hidden_size) =
+            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
+
+        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64).sqrt())
+            .map_err(EmbedError::tensor_shape)
+    }
+
+    fn mean_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        let (_n_sentence, n_tokens, _hidden_size) =
+            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
+
+        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
+            .map_err(EmbedError::tensor_shape)
+    }
+
+    fn max_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        embeddings.max(1).map_err(EmbedError::tensor_shape)
+    }
+
+    fn last_token_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        let (_n_sentence, n_tokens, _hidden_size) =
+            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
+
+        embeddings.get_on_dim(1, n_tokens - 1).map_err(EmbedError::tensor_value)
+    }
+
     pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
         let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
         let token_ids = tokens.get_ids();
@@ -192,11 +303,8 @@
             .forward(&token_ids, &token_type_ids, None)
             .map_err(EmbedError::model_forward)?;
 
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) =
-            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
-        let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
-            .map_err(EmbedError::tensor_shape)?;
+        let embedding = Self::pooling(embeddings, self.pooling)?;
+        let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
 
         let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
         Ok(embedding)
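A quick way to sanity-check the selection logic outside of milli: the standalone sketch below mirrors the PoolingConfig -> Pooling mapping added in hf.rs and feeds it a hypothetical 1_Pooling/config.json of the kind sentence-transformers models ship. It is illustration only, not part of the patch; it assumes the serde (with the derive feature) and serde_json crates, and the JSON values are made up.

// pooling_sketch.rs - standalone illustration, not part of milli.
// Mirrors the PoolingConfig -> Pooling mapping from the patch so the precedence
// (cls > mean > last-token > mean-sqrt-len > max > default mean) is easy to test.
use serde::Deserialize;

#[derive(Clone, Copy, Deserialize)]
struct PoolingConfig {
    #[serde(default)]
    pooling_mode_cls_token: bool,
    #[serde(default)]
    pooling_mode_mean_tokens: bool,
    #[serde(default)]
    pooling_mode_max_tokens: bool,
    #[serde(default)]
    pooling_mode_mean_sqrt_len_tokens: bool,
    #[serde(default)]
    pooling_mode_lasttoken: bool,
}

#[derive(Debug, Clone, Copy, Default, PartialEq)]
enum Pooling {
    #[default]
    Mean,
    Cls,
    Max,
    MeanSqrtLen,
    LastToken,
}

impl From<PoolingConfig> for Pooling {
    fn from(value: PoolingConfig) -> Self {
        // Same order of precedence as the patch: the first flag that is set wins,
        // and a missing or all-false config falls back to mean pooling.
        if value.pooling_mode_cls_token {
            Self::Cls
        } else if value.pooling_mode_mean_tokens {
            Self::Mean
        } else if value.pooling_mode_lasttoken {
            Self::LastToken
        } else if value.pooling_mode_mean_sqrt_len_tokens {
            Self::MeanSqrtLen
        } else if value.pooling_mode_max_tokens {
            Self::Max
        } else {
            Self::default()
        }
    }
}

fn main() {
    // Hypothetical 1_Pooling/config.json for a CLS-pooled model; unknown keys
    // such as word_embedding_dimension are ignored by serde, as in the patch.
    let json = r#"{
        "word_embedding_dimension": 384,
        "pooling_mode_cls_token": true,
        "pooling_mode_mean_tokens": false,
        "pooling_mode_max_tokens": false,
        "pooling_mode_mean_sqrt_len_tokens": false
    }"#;
    let config: PoolingConfig = serde_json::from_str(json).unwrap();
    assert_eq!(Pooling::from(config), Pooling::Cls);
    println!("selected pooling: {:?}", Pooling::from(config));
}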