Support pooling

This commit is contained in:
Louis Dureuil 2025-02-18 14:16:41 +01:00
parent 0f1aeb8eaa
commit 11759c4be4
No known key found for this signature in database
2 changed files with 166 additions and 14 deletions

View File

@ -262,6 +262,31 @@ impl NewEmbedderError {
}
}
pub fn open_pooling_config(
pooling_config_filename: PathBuf,
inner: std::io::Error,
) -> NewEmbedderError {
let open_config = OpenPoolingConfig { filename: pooling_config_filename, inner };
Self {
kind: NewEmbedderErrorKind::OpenPoolingConfig(open_config),
fault: FaultSource::Runtime,
}
}
pub fn deserialize_pooling_config(
model_name: String,
pooling_config_filename: PathBuf,
inner: serde_json::Error,
) -> NewEmbedderError {
let deserialize_pooling_config =
DeserializePoolingConfig { model_name, filename: pooling_config_filename, inner };
Self {
kind: NewEmbedderErrorKind::DeserializePoolingConfig(deserialize_pooling_config),
fault: FaultSource::Runtime,
}
}
pub fn open_tokenizer(
tokenizer_filename: PathBuf,
inner: Box<dyn std::error::Error + Send + Sync>,
@ -319,6 +344,13 @@ pub struct OpenConfig {
pub inner: std::io::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("could not open pooling config at {filename}: {inner}")]
pub struct OpenPoolingConfig {
pub filename: PathBuf,
pub inner: std::io::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
pub struct DeserializeConfig {
@ -327,6 +359,14 @@ pub struct DeserializeConfig {
pub inner: serde_json::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("for model '{model_name}', could not deserialize file at `{filename}` as a pooling config: {inner}")]
pub struct DeserializePoolingConfig {
pub model_name: String,
pub filename: PathBuf,
pub inner: serde_json::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
if architectures.is_empty() {
@ -354,8 +394,12 @@ pub enum NewEmbedderErrorKind {
#[error(transparent)]
OpenConfig(OpenConfig),
#[error(transparent)]
OpenPoolingConfig(OpenPoolingConfig),
#[error(transparent)]
DeserializeConfig(DeserializeConfig),
#[error(transparent)]
DeserializePoolingConfig(DeserializePoolingConfig),
#[error(transparent)]
UnsupportedModel(UnsupportedModel),
#[error(transparent)]
OpenTokenizer(OpenTokenizer),

View File

@ -58,6 +58,7 @@ pub struct Embedder {
tokenizer: Tokenizer,
options: EmbedderOptions,
dimensions: usize,
pooling: Pooling,
}
impl std::fmt::Debug for Embedder {
@ -66,10 +67,53 @@ impl std::fmt::Debug for Embedder {
.field("model", &self.options.model)
.field("tokenizer", &self.tokenizer)
.field("options", &self.options)
.field("pooling", &self.pooling)
.finish()
}
}
#[derive(Clone, Copy, serde::Deserialize)]
struct PoolingConfig {
#[serde(default)]
pub pooling_mode_cls_token: bool,
#[serde(default)]
pub pooling_mode_mean_tokens: bool,
#[serde(default)]
pub pooling_mode_max_tokens: bool,
#[serde(default)]
pub pooling_mode_mean_sqrt_len_tokens: bool,
#[serde(default)]
pub pooling_mode_lasttoken: bool,
}
#[derive(Debug, Clone, Copy, Default)]
pub enum Pooling {
#[default]
Mean,
Cls,
Max,
MeanSqrtLen,
LastToken,
}
impl From<PoolingConfig> for Pooling {
fn from(value: PoolingConfig) -> Self {
if value.pooling_mode_cls_token {
Self::Cls
} else if value.pooling_mode_mean_tokens {
Self::Mean
} else if value.pooling_mode_lasttoken {
Self::LastToken
} else if value.pooling_mode_mean_sqrt_len_tokens {
Self::MeanSqrtLen
} else if value.pooling_mode_max_tokens {
Self::Max
} else {
Self::default()
}
}
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
let device = match candle_core::Device::cuda_if_available(0) {
@ -83,7 +127,7 @@ impl Embedder {
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
None => Repo::model(options.model.clone()),
};
let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
let (config_filename, tokenizer_filename, weights_filename, weight_source, pooling) = {
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
let api = api.repo(repo);
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
@ -97,7 +141,36 @@ impl Embedder {
})
.map_err(NewEmbedderError::api_get)?
};
(config, tokenizer, weights, source)
let pooling = match api.get("1_Pooling/config.json") {
Ok(pooling) => Some(pooling),
Err(hf_hub::api::sync::ApiError::RequestError(error))
if matches!(*error, ureq::Error::Status(404, _,)) =>
{
// ignore the error if the file simply doesn't exist
None
}
Err(error) => return Err(NewEmbedderError::api_get(error)),
};
let pooling: Pooling = match pooling {
Some(pooling_filename) => {
let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
})?;
let pooling: PoolingConfig =
serde_json::from_str(&pooling).map_err(|inner| {
NewEmbedderError::deserialize_pooling_config(
options.model.clone(),
pooling_filename,
inner,
)
})?;
pooling.into()
}
None => Pooling::default(),
};
(config, tokenizer, weights, source, pooling)
};
let config = std::fs::read_to_string(&config_filename)
@ -122,6 +195,8 @@ impl Embedder {
},
};
tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config");
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
if let Some(pp) = tokenizer.get_padding_mut() {
@ -134,7 +209,7 @@ impl Embedder {
tokenizer.with_padding(Some(pp));
}
let mut this = Self { model, tokenizer, options, dimensions: 0 };
let mut this = Self { model, tokenizer, options, dimensions: 0, pooling };
let embeddings = this
.embed(vec!["test".into()])
@ -168,17 +243,53 @@ impl Embedder {
.forward(&token_ids, &token_type_ids, None)
.map_err(EmbedError::model_forward)?;
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) =
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
.map_err(EmbedError::tensor_shape)?;
let embeddings = Self::pooling(embeddings, self.pooling)?;
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
Ok(embeddings)
}
fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
match pooling {
Pooling::Mean => Self::mean_pooling(embeddings),
Pooling::Cls => Self::cls_pooling(embeddings),
Pooling::Max => Self::max_pooling(embeddings),
Pooling::MeanSqrtLen => Self::mean_sqrt_pooling(embeddings),
Pooling::LastToken => Self::last_token_pooling(embeddings),
}
}
fn cls_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
embeddings.get_on_dim(1, 0).map_err(EmbedError::tensor_value)
}
fn mean_sqrt_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
let (_n_sentence, n_tokens, _hidden_size) =
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
(embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64).sqrt())
.map_err(EmbedError::tensor_shape)
}
fn mean_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
let (_n_sentence, n_tokens, _hidden_size) =
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
(embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
.map_err(EmbedError::tensor_shape)
}
fn max_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
embeddings.max(1).map_err(EmbedError::tensor_shape)
}
fn last_token_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
let (_n_sentence, n_tokens, _hidden_size) =
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
embeddings.get_on_dim(1, n_tokens - 1).map_err(EmbedError::tensor_value)
}
pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
let token_ids = tokens.get_ids();
@ -192,11 +303,8 @@ impl Embedder {
.forward(&token_ids, &token_type_ids, None)
.map_err(EmbedError::model_forward)?;
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) =
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
.map_err(EmbedError::tensor_shape)?;
let embedding = Self::pooling(embeddings, self.pooling)?;
let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
Ok(embedding)