From 548c8247c29f90e52bfe85caf397586d45dec20a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Tue, 28 Nov 2023 10:11:17 +0100 Subject: [PATCH] Create and use real error types in the codecs --- .../facet/field_doc_id_facet_codec.rs | 5 +-- .../src/heed_codec/facet/ordered_f64_codec.rs | 13 ++++++-- .../heed_codec/field_id_word_count_codec.rs | 7 ++--- milli/src/heed_codec/mod.rs | 5 +++ milli/src/heed_codec/script_language_codec.rs | 14 +++------ milli/src/heed_codec/str_beu32_codec.rs | 6 ++-- milli/src/heed_codec/str_str_u8_codec.rs | 31 ++++++++----------- 7 files changed, 43 insertions(+), 38 deletions(-) diff --git a/milli/src/heed_codec/facet/field_doc_id_facet_codec.rs b/milli/src/heed_codec/facet/field_doc_id_facet_codec.rs index a0bea2c42..7e281adfa 100644 --- a/milli/src/heed_codec/facet/field_doc_id_facet_codec.rs +++ b/milli/src/heed_codec/facet/field_doc_id_facet_codec.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use heed::{BoxedError, BytesDecode, BytesEncode}; +use crate::heed_codec::SliceTooShortError; use crate::{try_split_array_at, DocumentId, FieldId}; pub struct FieldDocIdFacetCodec(PhantomData); @@ -14,10 +15,10 @@ where type DItem = (FieldId, DocumentId, C::DItem); fn bytes_decode(bytes: &'a [u8]) -> Result { - let (field_id_bytes, bytes) = try_split_array_at(bytes).unwrap(); + let (field_id_bytes, bytes) = try_split_array_at(bytes).ok_or(SliceTooShortError)?; let field_id = u16::from_be_bytes(field_id_bytes); - let (document_id_bytes, bytes) = try_split_array_at(bytes).unwrap(); + let (document_id_bytes, bytes) = try_split_array_at(bytes).ok_or(SliceTooShortError)?; let document_id = u32::from_be_bytes(document_id_bytes); let value = C::bytes_decode(bytes)?; diff --git a/milli/src/heed_codec/facet/ordered_f64_codec.rs b/milli/src/heed_codec/facet/ordered_f64_codec.rs index 64bb0b0cd..b692b2363 100644 --- a/milli/src/heed_codec/facet/ordered_f64_codec.rs +++ b/milli/src/heed_codec/facet/ordered_f64_codec.rs @@ -2,8 +2,10 @@ use std::borrow::Cow; use std::convert::TryInto; use heed::{BoxedError, BytesDecode}; +use thiserror::Error; use crate::facet::value_encoding::f64_into_bytes; +use crate::heed_codec::SliceTooShortError; pub struct OrderedF64Codec; @@ -12,7 +14,7 @@ impl<'a> BytesDecode<'a> for OrderedF64Codec { fn bytes_decode(bytes: &'a [u8]) -> Result { if bytes.len() < 16 { - Err(BoxedError::from("invalid slice length")) + Err(SliceTooShortError.into()) } else { bytes[8..].try_into().map(f64::from_be_bytes).map_err(Into::into) } @@ -26,8 +28,7 @@ impl heed::BytesEncode<'_> for OrderedF64Codec { let mut buffer = [0u8; 16]; // write the globally ordered float - let bytes = f64_into_bytes(*f) - .ok_or_else(|| BoxedError::from("cannot generate a globally ordered float"))?; + let bytes = f64_into_bytes(*f).ok_or(InvalidGloballyOrderedFloatError { float: *f })?; buffer[..8].copy_from_slice(&bytes[..]); // Then the f64 value just to be able to read it back let bytes = f.to_be_bytes(); @@ -36,3 +37,9 @@ impl heed::BytesEncode<'_> for OrderedF64Codec { Ok(Cow::Owned(buffer.to_vec())) } } + +#[derive(Error, Debug)] +#[error("the float {float} cannot be converted to a globally ordered representation")] +pub struct InvalidGloballyOrderedFloatError { + float: f64, +} diff --git a/milli/src/heed_codec/field_id_word_count_codec.rs b/milli/src/heed_codec/field_id_word_count_codec.rs index 9e7f044c5..19d8d63c6 100644 --- a/milli/src/heed_codec/field_id_word_count_codec.rs +++ b/milli/src/heed_codec/field_id_word_count_codec.rs @@ -2,6 +2,7 @@ use std::borrow::Cow; use heed::BoxedError; +use super::SliceTooShortError; use crate::{try_split_array_at, FieldId}; pub struct FieldIdWordCountCodec; @@ -10,11 +11,9 @@ impl<'a> heed::BytesDecode<'a> for FieldIdWordCountCodec { type DItem = (FieldId, u8); fn bytes_decode(bytes: &'a [u8]) -> Result { - let (field_id_bytes, bytes) = - try_split_array_at(bytes).ok_or("invalid slice length").map_err(BoxedError::from)?; + let (field_id_bytes, bytes) = try_split_array_at(bytes).ok_or(SliceTooShortError)?; let field_id = u16::from_be_bytes(field_id_bytes); - let ([word_count], _nothing) = - try_split_array_at(bytes).ok_or("invalid slice length").map_err(BoxedError::from)?; + let ([word_count], _nothing) = try_split_array_at(bytes).ok_or(SliceTooShortError)?; Ok((field_id, word_count)) } } diff --git a/milli/src/heed_codec/mod.rs b/milli/src/heed_codec/mod.rs index dde77a5f3..449d1955c 100644 --- a/milli/src/heed_codec/mod.rs +++ b/milli/src/heed_codec/mod.rs @@ -15,6 +15,7 @@ mod str_str_u8_codec; pub use byte_slice_ref::BytesRefCodec; use heed::BoxedError; pub use str_ref::StrRefCodec; +use thiserror::Error; pub use self::beu16_str_codec::BEU16StrCodec; pub use self::beu32_str_codec::BEU32StrCodec; @@ -34,3 +35,7 @@ pub trait BytesDecodeOwned { fn bytes_decode_owned(bytes: &[u8]) -> Result; } + +#[derive(Error, Debug)] +#[error("the slice is too short")] +pub struct SliceTooShortError; diff --git a/milli/src/heed_codec/script_language_codec.rs b/milli/src/heed_codec/script_language_codec.rs index 013ec62bb..ef2ad4bec 100644 --- a/milli/src/heed_codec/script_language_codec.rs +++ b/milli/src/heed_codec/script_language_codec.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::ffi::CStr; use std::str; use charabia::{Language, Script}; @@ -10,17 +11,12 @@ impl<'a> heed::BytesDecode<'a> for ScriptLanguageCodec { type DItem = (Script, Language); fn bytes_decode(bytes: &'a [u8]) -> Result { - let sep = bytes - .iter() - .position(|b| *b == 0) - .ok_or("cannot find nul byte") - .map_err(BoxedError::from)?; - let (s_bytes, l_bytes) = bytes.split_at(sep); - let script = str::from_utf8(s_bytes)?; + let cstr = CStr::from_bytes_until_nul(bytes)?; + let script = cstr.to_str()?; let script_name = Script::from_name(script); - let lan = str::from_utf8(l_bytes)?; // skip '\0' byte between the two strings. - let lan_name = Language::from_name(&lan[1..]); + let lan = str::from_utf8(&bytes[script.len() + 1..])?; + let lan_name = Language::from_name(lan); Ok((script_name, lan_name)) } diff --git a/milli/src/heed_codec/str_beu32_codec.rs b/milli/src/heed_codec/str_beu32_codec.rs index c654a1811..c76ea2a26 100644 --- a/milli/src/heed_codec/str_beu32_codec.rs +++ b/milli/src/heed_codec/str_beu32_codec.rs @@ -5,6 +5,8 @@ use std::str; use heed::BoxedError; +use super::SliceTooShortError; + pub struct StrBEU32Codec; impl<'a> heed::BytesDecode<'a> for StrBEU32Codec { @@ -14,7 +16,7 @@ impl<'a> heed::BytesDecode<'a> for StrBEU32Codec { let footer_len = size_of::(); if bytes.len() < footer_len { - return Err(BoxedError::from("cannot extract footer from bytes")); + return Err(SliceTooShortError.into()); } let (word, bytes) = bytes.split_at(bytes.len() - footer_len); @@ -48,7 +50,7 @@ impl<'a> heed::BytesDecode<'a> for StrBEU16Codec { let footer_len = size_of::(); if bytes.len() < footer_len + 1 { - return Err(BoxedError::from("cannot extract footer from bytes")); + return Err(SliceTooShortError.into()); } let (word_plus_nul_byte, bytes) = bytes.split_at(bytes.len() - footer_len); diff --git a/milli/src/heed_codec/str_str_u8_codec.rs b/milli/src/heed_codec/str_str_u8_codec.rs index 743ddb1f7..0aedf0c94 100644 --- a/milli/src/heed_codec/str_str_u8_codec.rs +++ b/milli/src/heed_codec/str_str_u8_codec.rs @@ -1,24 +1,22 @@ use std::borrow::Cow; +use std::ffi::CStr; use std::str; use heed::BoxedError; +use super::SliceTooShortError; + pub struct U8StrStrCodec; impl<'a> heed::BytesDecode<'a> for U8StrStrCodec { type DItem = (u8, &'a str, &'a str); fn bytes_decode(bytes: &'a [u8]) -> Result { - let (n, bytes) = bytes.split_first().ok_or("not enough bytes").map_err(BoxedError::from)?; - let s1_end = bytes - .iter() - .position(|b| *b == 0) - .ok_or("cannot find nul byte") - .map_err(BoxedError::from)?; - let (s1_bytes, rest) = bytes.split_at(s1_end); - let s2_bytes = &rest[1..]; - let s1 = str::from_utf8(s1_bytes)?; - let s2 = str::from_utf8(s2_bytes)?; + let (n, bytes) = bytes.split_first().ok_or(SliceTooShortError)?; + let cstr = CStr::from_bytes_until_nul(bytes)?; + let s1 = cstr.to_str()?; + // skip '\0' byte between the two strings. + let s2 = str::from_utf8(&bytes[s1.len() + 1..])?; Ok((*n, s1, s2)) } } @@ -41,14 +39,11 @@ impl<'a> heed::BytesDecode<'a> for UncheckedU8StrStrCodec { type DItem = (u8, &'a [u8], &'a [u8]); fn bytes_decode(bytes: &'a [u8]) -> Result { - let (n, bytes) = bytes.split_first().ok_or("not enough bytes").map_err(BoxedError::from)?; - let s1_end = bytes - .iter() - .position(|b| *b == 0) - .ok_or("cannot find nul byte") - .map_err(BoxedError::from)?; - let (s1_bytes, rest) = bytes.split_at(s1_end); - let s2_bytes = &rest[1..]; + let (n, bytes) = bytes.split_first().ok_or(SliceTooShortError)?; + let cstr = CStr::from_bytes_until_nul(bytes)?; + let s1_bytes = cstr.to_bytes(); + // skip '\0' byte between the two strings. + let s2_bytes = &bytes[s1_bytes.len() + 1..]; Ok((*n, s1_bytes, s2_bytes)) } }