mirror of
https://github.com/meilisearch/meilisearch.git
synced 2025-03-06 22:02:34 +08:00
Support pooling
This commit is contained in:
parent
0f1aeb8eaa
commit
11759c4be4
@ -262,6 +262,31 @@ impl NewEmbedderError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn open_pooling_config(
|
||||||
|
pooling_config_filename: PathBuf,
|
||||||
|
inner: std::io::Error,
|
||||||
|
) -> NewEmbedderError {
|
||||||
|
let open_config = OpenPoolingConfig { filename: pooling_config_filename, inner };
|
||||||
|
|
||||||
|
Self {
|
||||||
|
kind: NewEmbedderErrorKind::OpenPoolingConfig(open_config),
|
||||||
|
fault: FaultSource::Runtime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserialize_pooling_config(
|
||||||
|
model_name: String,
|
||||||
|
pooling_config_filename: PathBuf,
|
||||||
|
inner: serde_json::Error,
|
||||||
|
) -> NewEmbedderError {
|
||||||
|
let deserialize_pooling_config =
|
||||||
|
DeserializePoolingConfig { model_name, filename: pooling_config_filename, inner };
|
||||||
|
Self {
|
||||||
|
kind: NewEmbedderErrorKind::DeserializePoolingConfig(deserialize_pooling_config),
|
||||||
|
fault: FaultSource::Runtime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn open_tokenizer(
|
pub fn open_tokenizer(
|
||||||
tokenizer_filename: PathBuf,
|
tokenizer_filename: PathBuf,
|
||||||
inner: Box<dyn std::error::Error + Send + Sync>,
|
inner: Box<dyn std::error::Error + Send + Sync>,
|
||||||
@ -319,6 +344,13 @@ pub struct OpenConfig {
|
|||||||
pub inner: std::io::Error,
|
pub inner: std::io::Error,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("could not open pooling config at {filename}: {inner}")]
|
||||||
|
pub struct OpenPoolingConfig {
|
||||||
|
pub filename: PathBuf,
|
||||||
|
pub inner: std::io::Error,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
#[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
|
#[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
|
||||||
pub struct DeserializeConfig {
|
pub struct DeserializeConfig {
|
||||||
@ -327,6 +359,14 @@ pub struct DeserializeConfig {
|
|||||||
pub inner: serde_json::Error,
|
pub inner: serde_json::Error,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("for model '{model_name}', could not deserialize file at `{filename}` as a pooling config: {inner}")]
|
||||||
|
pub struct DeserializePoolingConfig {
|
||||||
|
pub model_name: String,
|
||||||
|
pub filename: PathBuf,
|
||||||
|
pub inner: serde_json::Error,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
|
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
|
||||||
if architectures.is_empty() {
|
if architectures.is_empty() {
|
||||||
@ -354,8 +394,12 @@ pub enum NewEmbedderErrorKind {
|
|||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
OpenConfig(OpenConfig),
|
OpenConfig(OpenConfig),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
|
OpenPoolingConfig(OpenPoolingConfig),
|
||||||
|
#[error(transparent)]
|
||||||
DeserializeConfig(DeserializeConfig),
|
DeserializeConfig(DeserializeConfig),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
|
DeserializePoolingConfig(DeserializePoolingConfig),
|
||||||
|
#[error(transparent)]
|
||||||
UnsupportedModel(UnsupportedModel),
|
UnsupportedModel(UnsupportedModel),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
OpenTokenizer(OpenTokenizer),
|
OpenTokenizer(OpenTokenizer),
|
||||||
|
@ -58,6 +58,7 @@ pub struct Embedder {
|
|||||||
tokenizer: Tokenizer,
|
tokenizer: Tokenizer,
|
||||||
options: EmbedderOptions,
|
options: EmbedderOptions,
|
||||||
dimensions: usize,
|
dimensions: usize,
|
||||||
|
pooling: Pooling,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for Embedder {
|
impl std::fmt::Debug for Embedder {
|
||||||
@ -66,10 +67,53 @@ impl std::fmt::Debug for Embedder {
|
|||||||
.field("model", &self.options.model)
|
.field("model", &self.options.model)
|
||||||
.field("tokenizer", &self.tokenizer)
|
.field("tokenizer", &self.tokenizer)
|
||||||
.field("options", &self.options)
|
.field("options", &self.options)
|
||||||
|
.field("pooling", &self.pooling)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, serde::Deserialize)]
|
||||||
|
struct PoolingConfig {
|
||||||
|
#[serde(default)]
|
||||||
|
pub pooling_mode_cls_token: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub pooling_mode_mean_tokens: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub pooling_mode_max_tokens: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub pooling_mode_mean_sqrt_len_tokens: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub pooling_mode_lasttoken: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub enum Pooling {
|
||||||
|
#[default]
|
||||||
|
Mean,
|
||||||
|
Cls,
|
||||||
|
Max,
|
||||||
|
MeanSqrtLen,
|
||||||
|
LastToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<PoolingConfig> for Pooling {
|
||||||
|
fn from(value: PoolingConfig) -> Self {
|
||||||
|
if value.pooling_mode_cls_token {
|
||||||
|
Self::Cls
|
||||||
|
} else if value.pooling_mode_mean_tokens {
|
||||||
|
Self::Mean
|
||||||
|
} else if value.pooling_mode_lasttoken {
|
||||||
|
Self::LastToken
|
||||||
|
} else if value.pooling_mode_mean_sqrt_len_tokens {
|
||||||
|
Self::MeanSqrtLen
|
||||||
|
} else if value.pooling_mode_max_tokens {
|
||||||
|
Self::Max
|
||||||
|
} else {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl Embedder {
|
impl Embedder {
|
||||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||||
let device = match candle_core::Device::cuda_if_available(0) {
|
let device = match candle_core::Device::cuda_if_available(0) {
|
||||||
@ -83,7 +127,7 @@ impl Embedder {
|
|||||||
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
||||||
None => Repo::model(options.model.clone()),
|
None => Repo::model(options.model.clone()),
|
||||||
};
|
};
|
||||||
let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
|
let (config_filename, tokenizer_filename, weights_filename, weight_source, pooling) = {
|
||||||
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
|
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
|
||||||
let api = api.repo(repo);
|
let api = api.repo(repo);
|
||||||
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
|
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
|
||||||
@ -97,7 +141,36 @@ impl Embedder {
|
|||||||
})
|
})
|
||||||
.map_err(NewEmbedderError::api_get)?
|
.map_err(NewEmbedderError::api_get)?
|
||||||
};
|
};
|
||||||
(config, tokenizer, weights, source)
|
let pooling = match api.get("1_Pooling/config.json") {
|
||||||
|
Ok(pooling) => Some(pooling),
|
||||||
|
Err(hf_hub::api::sync::ApiError::RequestError(error))
|
||||||
|
if matches!(*error, ureq::Error::Status(404, _,)) =>
|
||||||
|
{
|
||||||
|
// ignore the error if the file simply doesn't exist
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Err(error) => return Err(NewEmbedderError::api_get(error)),
|
||||||
|
};
|
||||||
|
let pooling: Pooling = match pooling {
|
||||||
|
Some(pooling_filename) => {
|
||||||
|
let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
|
||||||
|
NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let pooling: PoolingConfig =
|
||||||
|
serde_json::from_str(&pooling).map_err(|inner| {
|
||||||
|
NewEmbedderError::deserialize_pooling_config(
|
||||||
|
options.model.clone(),
|
||||||
|
pooling_filename,
|
||||||
|
inner,
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
pooling.into()
|
||||||
|
}
|
||||||
|
None => Pooling::default(),
|
||||||
|
};
|
||||||
|
|
||||||
|
(config, tokenizer, weights, source, pooling)
|
||||||
};
|
};
|
||||||
|
|
||||||
let config = std::fs::read_to_string(&config_filename)
|
let config = std::fs::read_to_string(&config_filename)
|
||||||
@ -122,6 +195,8 @@ impl Embedder {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
tracing::debug!(model = options.model, weight=?weight_source, pooling=?pooling, "model config");
|
||||||
|
|
||||||
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
|
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
|
||||||
|
|
||||||
if let Some(pp) = tokenizer.get_padding_mut() {
|
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||||
@ -134,7 +209,7 @@ impl Embedder {
|
|||||||
tokenizer.with_padding(Some(pp));
|
tokenizer.with_padding(Some(pp));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut this = Self { model, tokenizer, options, dimensions: 0 };
|
let mut this = Self { model, tokenizer, options, dimensions: 0, pooling };
|
||||||
|
|
||||||
let embeddings = this
|
let embeddings = this
|
||||||
.embed(vec!["test".into()])
|
.embed(vec!["test".into()])
|
||||||
@ -168,17 +243,53 @@ impl Embedder {
|
|||||||
.forward(&token_ids, &token_type_ids, None)
|
.forward(&token_ids, &token_type_ids, None)
|
||||||
.map_err(EmbedError::model_forward)?;
|
.map_err(EmbedError::model_forward)?;
|
||||||
|
|
||||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
let embeddings = Self::pooling(embeddings, self.pooling)?;
|
||||||
let (_n_sentence, n_tokens, _hidden_size) =
|
|
||||||
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
|
||||||
|
|
||||||
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
|
||||||
.map_err(EmbedError::tensor_shape)?;
|
|
||||||
|
|
||||||
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
||||||
Ok(embeddings)
|
Ok(embeddings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
|
||||||
|
match pooling {
|
||||||
|
Pooling::Mean => Self::mean_pooling(embeddings),
|
||||||
|
Pooling::Cls => Self::cls_pooling(embeddings),
|
||||||
|
Pooling::Max => Self::max_pooling(embeddings),
|
||||||
|
Pooling::MeanSqrtLen => Self::mean_sqrt_pooling(embeddings),
|
||||||
|
Pooling::LastToken => Self::last_token_pooling(embeddings),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cls_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
|
||||||
|
embeddings.get_on_dim(1, 0).map_err(EmbedError::tensor_value)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mean_sqrt_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
|
||||||
|
let (_n_sentence, n_tokens, _hidden_size) =
|
||||||
|
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
||||||
|
|
||||||
|
(embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64).sqrt())
|
||||||
|
.map_err(EmbedError::tensor_shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mean_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
|
||||||
|
let (_n_sentence, n_tokens, _hidden_size) =
|
||||||
|
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
||||||
|
|
||||||
|
(embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||||
|
.map_err(EmbedError::tensor_shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn max_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
|
||||||
|
embeddings.max(1).map_err(EmbedError::tensor_shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn last_token_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
|
||||||
|
let (_n_sentence, n_tokens, _hidden_size) =
|
||||||
|
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
||||||
|
|
||||||
|
embeddings.get_on_dim(1, n_tokens - 1).map_err(EmbedError::tensor_value)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
|
pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
|
||||||
let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
|
let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
|
||||||
let token_ids = tokens.get_ids();
|
let token_ids = tokens.get_ids();
|
||||||
@ -192,11 +303,8 @@ impl Embedder {
|
|||||||
.forward(&token_ids, &token_type_ids, None)
|
.forward(&token_ids, &token_type_ids, None)
|
||||||
.map_err(EmbedError::model_forward)?;
|
.map_err(EmbedError::model_forward)?;
|
||||||
|
|
||||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
let embedding = Self::pooling(embeddings, self.pooling)?;
|
||||||
let (_n_sentence, n_tokens, _hidden_size) =
|
|
||||||
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
|
||||||
let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
|
||||||
.map_err(EmbedError::tensor_shape)?;
|
|
||||||
let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
|
let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
|
||||||
let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
|
let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
|
||||||
Ok(embedding)
|
Ok(embedding)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user