From 303e601b877f1477b2e023d8c51f21f08c163efc Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Tue, 23 Jul 2024 15:05:45 +0200
Subject: [PATCH] HuggingFace: Clearer error message when a model is not
 supported

---
 milli/src/vector/error.rs | 54 ++++++++++++++++++++++++++++++++++-----
 milli/src/vector/hf.rs    |  7 ++++-
 2 files changed, 53 insertions(+), 8 deletions(-)

diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs
index af9718f08..975561dc3 100644
--- a/milli/src/vector/error.rs
+++ b/milli/src/vector/error.rs
@@ -217,14 +217,39 @@ impl NewEmbedderError {
     }
 
     pub fn deserialize_config(
+        model_name: String,
         config: String,
         config_filename: PathBuf,
         inner: serde_json::Error,
     ) -> NewEmbedderError {
-        let deserialize_config = DeserializeConfig { config, filename: config_filename, inner };
-        Self {
-            kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
-            fault: FaultSource::Runtime,
+        match serde_json::from_str(&config) {
+            Ok(value) => {
+                let value: serde_json::Value = value;
+                let architectures = match value.get("architectures") {
+                    Some(serde_json::Value::Array(architectures)) => architectures
+                        .iter()
+                        .filter_map(|value| match value {
+                            serde_json::Value::String(s) => Some(s.to_owned()),
+                            _ => None,
+                        })
+                        .collect(),
+                    _ => vec![],
+                };
+
+                let unsupported_model = UnsupportedModel { model_name, inner, architectures };
+                Self {
+                    kind: NewEmbedderErrorKind::UnsupportedModel(unsupported_model),
+                    fault: FaultSource::User,
+                }
+            }
+            Err(error) => {
+                let deserialize_config =
+                    DeserializeConfig { model_name, filename: config_filename, inner: error };
+                Self {
+                    kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
+                    fault: FaultSource::Runtime,
+                }
+            }
         }
     }
 
@@ -252,7 +277,7 @@ impl NewEmbedderError {
     }
 
     pub fn safetensor_weight(inner: candle_core::Error) -> Self {
-        Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
+        Self { kind: NewEmbedderErrorKind::SafetensorWeight(inner), fault: FaultSource::Runtime }
     }
 
     pub fn load_model(inner: candle_core::Error) -> Self {
@@ -275,13 +300,26 @@ pub struct OpenConfig {
 }
 
 #[derive(Debug, thiserror::Error)]
-#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")]
+#[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
 pub struct DeserializeConfig {
-    pub config: String,
+    pub model_name: String,
     pub filename: PathBuf,
     pub inner: serde_json::Error,
 }
 
+#[derive(Debug, thiserror::Error)]
+#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
+if architectures.is_empty() {
+    "\n - Note: only models with architecture \"BertModel\" are supported.".to_string()
+} else {
+    format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.")
+})]
+pub struct UnsupportedModel {
+    pub model_name: String,
+    pub inner: serde_json::Error,
+    pub architectures: Vec<String>,
+}
+
 #[derive(Debug, thiserror::Error)]
 #[error("could not open tokenizer at {filename}: {inner}")]
 pub struct OpenTokenizer {
@@ -298,6 +336,8 @@ pub enum NewEmbedderErrorKind {
     #[error(transparent)]
     DeserializeConfig(DeserializeConfig),
     #[error(transparent)]
+    UnsupportedModel(UnsupportedModel),
+    #[error(transparent)]
     OpenTokenizer(OpenTokenizer),
     #[error("could not build weights from Pytorch weights: {0}")]
     PytorchWeight(candle_core::Error),
diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs
index 58181941b..dc1e7d324 100644
--- a/milli/src/vector/hf.rs
+++ b/milli/src/vector/hf.rs
@@ -103,7 +103,12 @@ impl Embedder {
         let config = std::fs::read_to_string(&config_filename)
             .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
         let config: Config = serde_json::from_str(&config).map_err(|inner| {
-            NewEmbedderError::deserialize_config(config, config_filename, inner)
+            NewEmbedderError::deserialize_config(
+                options.model.clone(),
+                config,
+                config_filename,
+                inner,
+            )
         })?;
         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
            .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
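
Below the patch, a minimal standalone sketch (not part of the diff) of the message logic that the new `UnsupportedModel` error encodes in its `thiserror` format string: the note differs depending on whether the model's config.json declared any `architectures`. The model name and the inner error text here are made up for illustration, and a plain `&str` stands in for the real `serde_json::Error`; the helper mirrors, rather than reuses, the format string from the patch.

// Standalone sketch: mirrors the UnsupportedModel display logic from the patch.
fn unsupported_model_message(model_name: &str, inner: &str, architectures: &[String]) -> String {
    let note = if architectures.is_empty() {
        // config.json declared no `architectures`: generic note.
        "\n - Note: only models with architecture \"BertModel\" are supported.".to_string()
    } else {
        // config.json declared architectures: echo them back to the user.
        format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.")
    };
    format!("model `{model_name}` appears to be unsupported{note}\n - inner error: {inner}")
}

fn main() {
    // Hypothetical non-Bert model whose config.json fails to deserialize as the Bert `Config`.
    let msg = unsupported_model_message(
        "some-org/some-t5-model",
        "missing field `hidden_size`",
        &["T5ForConditionalGeneration".to_string()],
    );
    println!("{msg}");
}

Running this prints the model name, the architecture note, and the inner deserialization error, which is the shape of message a user now sees instead of the raw config dump that the old DeserializeConfig error appended.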