HuggingFace: Clearer error message when a model is not supported

This commit is contained in:
Louis Dureuil 2024-07-23 15:05:45 +02:00
parent f6d2c59bca
commit 303e601b87
No known key found for this signature in database
2 changed files with 53 additions and 8 deletions

View File

@ -217,14 +217,39 @@ impl NewEmbedderError {
} }
pub fn deserialize_config( pub fn deserialize_config(
model_name: String,
config: String, config: String,
config_filename: PathBuf, config_filename: PathBuf,
inner: serde_json::Error, inner: serde_json::Error,
) -> NewEmbedderError { ) -> NewEmbedderError {
let deserialize_config = DeserializeConfig { config, filename: config_filename, inner }; match serde_json::from_str(&config) {
Self { Ok(value) => {
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config), let value: serde_json::Value = value;
fault: FaultSource::Runtime, let architectures = match value.get("architectures") {
Some(serde_json::Value::Array(architectures)) => architectures
.iter()
.filter_map(|value| match value {
serde_json::Value::String(s) => Some(s.to_owned()),
_ => None,
})
.collect(),
_ => vec![],
};
let unsupported_model = UnsupportedModel { model_name, inner, architectures };
Self {
kind: NewEmbedderErrorKind::UnsupportedModel(unsupported_model),
fault: FaultSource::User,
}
}
Err(error) => {
let deserialize_config =
DeserializeConfig { model_name, filename: config_filename, inner: error };
Self {
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
fault: FaultSource::Runtime,
}
}
} }
} }
@ -252,7 +277,7 @@ impl NewEmbedderError {
} }
pub fn safetensor_weight(inner: candle_core::Error) -> Self { pub fn safetensor_weight(inner: candle_core::Error) -> Self {
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } Self { kind: NewEmbedderErrorKind::SafetensorWeight(inner), fault: FaultSource::Runtime }
} }
pub fn load_model(inner: candle_core::Error) -> Self { pub fn load_model(inner: candle_core::Error) -> Self {
@ -275,13 +300,26 @@ pub struct OpenConfig {
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")] #[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
pub struct DeserializeConfig { pub struct DeserializeConfig {
pub config: String, pub model_name: String,
pub filename: PathBuf, pub filename: PathBuf,
pub inner: serde_json::Error, pub inner: serde_json::Error,
} }
#[derive(Debug, thiserror::Error)]
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
if architectures.is_empty() {
"\n - Note: only models with architecture \"BertModel\" are supported.".to_string()
} else {
format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.")
})]
pub struct UnsupportedModel {
pub model_name: String,
pub inner: serde_json::Error,
pub architectures: Vec<String>,
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
#[error("could not open tokenizer at {filename}: {inner}")] #[error("could not open tokenizer at {filename}: {inner}")]
pub struct OpenTokenizer { pub struct OpenTokenizer {
@ -298,6 +336,8 @@ pub enum NewEmbedderErrorKind {
#[error(transparent)] #[error(transparent)]
DeserializeConfig(DeserializeConfig), DeserializeConfig(DeserializeConfig),
#[error(transparent)] #[error(transparent)]
UnsupportedModel(UnsupportedModel),
#[error(transparent)]
OpenTokenizer(OpenTokenizer), OpenTokenizer(OpenTokenizer),
#[error("could not build weights from Pytorch weights: {0}")] #[error("could not build weights from Pytorch weights: {0}")]
PytorchWeight(candle_core::Error), PytorchWeight(candle_core::Error),

View File

@ -103,7 +103,12 @@ impl Embedder {
let config = std::fs::read_to_string(&config_filename) let config = std::fs::read_to_string(&config_filename)
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?; .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
let config: Config = serde_json::from_str(&config).map_err(|inner| { let config: Config = serde_json::from_str(&config).map_err(|inner| {
NewEmbedderError::deserialize_config(config, config_filename, inner) NewEmbedderError::deserialize_config(
options.model.clone(),
config,
config_filename,
inner,
)
})?; })?;
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;