diff --git a/src/aead/gcm.rs b/src/aead/gcm.rs index 7fcfd88d8..0940bf43b 100644 --- a/src/aead/gcm.rs +++ b/src/aead/gcm.rs @@ -16,7 +16,7 @@ use self::ffi::{Block, BLOCK_LEN, ZERO_BLOCK}; use super::{aes_gcm, Aad}; use crate::{ bits::{BitLength, FromByteLen as _}, - error, + error::{self, InputTooLongError}, polyfill::{sliceutil::overwrite_at_start, NotSend}, }; use cfg_if::cfg_if; @@ -57,8 +57,10 @@ impl<'key, K: Gmult> Context<'key, K> { if in_out_len > aes_gcm::MAX_IN_OUT_LEN { return Err(error::Unspecified); } - let in_out_len = BitLength::from_byte_len(in_out_len)?; - let aad_len = BitLength::from_byte_len(aad.as_ref().len())?; + let in_out_len = BitLength::from_byte_len(in_out_len) + .map_err(|_: InputTooLongError| error::Unspecified)?; + let aad_len = BitLength::from_byte_len(aad.as_ref().len()) + .map_err(|_: InputTooLongError| error::Unspecified)?; // NIST SP800-38D Section 5.2.1.1 says that the maximum AAD length is // 2**64 - 1 bits, i.e. BitLength::MAX, so we don't need to do an diff --git a/src/bits.rs b/src/bits.rs index e719eebbf..fc3d0c602 100644 --- a/src/bits.rs +++ b/src/bits.rs @@ -14,7 +14,7 @@ //! Bit lengths. -use crate::{error, polyfill}; +use crate::{error::InputTooLongError, polyfill}; /// The length of something, in bits. /// @@ -27,36 +27,35 @@ pub(crate) trait FromByteLen: Sized { /// Constructs a `BitLength` from the given length in bytes. /// /// Fails if `bytes * 8` is too large for a `T`. - fn from_byte_len(bytes: T) -> Result; + fn from_byte_len(bytes: T) -> Result>; } impl FromByteLen for BitLength { #[inline] - fn from_byte_len(bytes: usize) -> Result { + fn from_byte_len(bytes: usize) -> Result { match bytes.checked_mul(8) { Some(bits) => Ok(Self(bits)), - None => Err(error::Unspecified), + None => Err(InputTooLongError::new(bytes)), } } } impl FromByteLen for BitLength { #[inline] - fn from_byte_len(bytes: u64) -> Result { + fn from_byte_len(bytes: u64) -> Result> { match bytes.checked_mul(8) { Some(bits) => Ok(Self(bits)), - None => Err(error::Unspecified), + None => Err(InputTooLongError::new(bytes)), } } } impl FromByteLen for BitLength { #[inline] - fn from_byte_len(bytes: usize) -> Result { - let bytes = polyfill::u64_from_usize(bytes); - match bytes.checked_mul(8) { + fn from_byte_len(bytes: usize) -> Result> { + match polyfill::u64_from_usize(bytes).checked_mul(8) { Some(bits) => Ok(Self(bits)), - None => Err(error::Unspecified), + None => Err(InputTooLongError::new(bytes)), } } } @@ -102,8 +101,8 @@ impl BitLength { #[cfg(feature = "alloc")] #[inline] - pub(crate) fn try_sub_1(self) -> Result { - let sum = self.0.checked_sub(1).ok_or(error::Unspecified)?; + pub(crate) fn try_sub_1(self) -> Result { + let sum = self.0.checked_sub(1).ok_or(crate::error::Unspecified)?; Ok(Self(sum)) } } diff --git a/src/digest.rs b/src/digest.rs index 23145ebdc..2eec80eaa 100644 --- a/src/digest.rs +++ b/src/digest.rs @@ -24,7 +24,8 @@ use self::{ }; use crate::{ bits::{BitLength, FromByteLen as _}, - cpu, debug, error, + cpu, debug, + error::{self, InputTooLongError}, polyfill::{self, slice, sliceutil}, }; use core::num::Wrapping; @@ -80,12 +81,15 @@ impl BlockContext { num_pending: usize, cpu_features: cpu::Features, ) -> Result { + // XXX: Choosing self.completed_bytes when the addition overflows is + // arbitrary. let completed_bytes = self .completed_bytes .checked_add(polyfill::u64_from_usize(num_pending)) - .ok_or_else(|| FinishError::too_much_input(self.completed_bytes))?; - let completed_bits = BitLength::from_byte_len(completed_bytes) - .map_err(|_: error::Unspecified| FinishError::too_much_input(self.completed_bytes))?; + .ok_or_else(|| InputTooLongError::new(self.completed_bytes)) + .map_err(FinishError::InputTooLong)?; + let completed_bits = + BitLength::from_byte_len(completed_bytes).map_err(FinishError::InputTooLong)?; let block_len = self.algorithm.block_len(); let block = &mut block[..block_len]; @@ -143,18 +147,12 @@ impl BlockContext { pub(crate) enum FinishError { #[allow(dead_code)] - TooMuchInput(u64), + InputTooLong(InputTooLongError), #[allow(dead_code)] PendingNotAPartialBlock(usize), } impl FinishError { - #[cold] - #[inline(never)] - fn too_much_input(completed_bytes: u64) -> Self { - Self::TooMuchInput(completed_bytes) - } - // unreachable #[cold] #[inline(never)] diff --git a/src/error/input_too_long.rs b/src/error/input_too_long.rs new file mode 100644 index 000000000..1a3429a38 --- /dev/null +++ b/src/error/input_too_long.rs @@ -0,0 +1,39 @@ +// Copyright 2024 Brian Smith. +// +// Permission to use, copy, modify, and/or distribute this software for any +// purpose with or without fee is hereby granted, provided that the above +// copyright notice and this permission notice appear in all copies. +// +// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES +// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY +// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION +// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN +// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +pub struct InputTooLongError { + /// Note that this might not actually be the (exact) length of the input, + /// and its units might be lost. For example, it could be any of the + /// following: + /// + /// * The length in bytes of the entire input. + /// * The length in bytes of some *part* of the input. + /// * A bit length. + /// * A length in terms of "blocks" or other grouping of input values. + /// * Some intermediate quantity that was used when checking the input + /// length. + /// * Some arbitrary value. + #[allow(dead_code)] + imprecise_input_length: T, +} + +impl InputTooLongError { + #[cold] + #[inline(never)] + pub(crate) fn new(imprecise_input_length: T) -> Self { + Self { + imprecise_input_length, + } + } +} diff --git a/src/error/into_unspecified.rs b/src/error/into_unspecified.rs index e094879f9..af52e856e 100644 --- a/src/error/into_unspecified.rs +++ b/src/error/into_unspecified.rs @@ -15,19 +15,19 @@ use crate::error::{KeyRejected, Unspecified}; impl From for Unspecified { - fn from(_: untrusted::EndOfInput) -> Self { - Self + fn from(source: untrusted::EndOfInput) -> Self { + super::erase(source) } } impl From for Unspecified { - fn from(_: core::array::TryFromSliceError) -> Self { - Self + fn from(source: core::array::TryFromSliceError) -> Self { + super::erase(source) } } impl From for Unspecified { - fn from(_: KeyRejected) -> Self { - Self + fn from(source: KeyRejected) -> Self { + super::erase(source) } } diff --git a/src/error/mod.rs b/src/error/mod.rs index 5ff2aafe1..ad3384808 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -16,6 +16,15 @@ pub use self::{key_rejected::KeyRejected, unspecified::Unspecified}; +pub(crate) use self::input_too_long::InputTooLongError; + +mod input_too_long; mod into_unspecified; mod key_rejected; mod unspecified; + +#[cold] +#[inline(never)] +pub(crate) fn erase(_: T) -> Unspecified { + Unspecified +} diff --git a/src/rsa/public_modulus.rs b/src/rsa/public_modulus.rs index 81d2b8f28..ad95ce45b 100644 --- a/src/rsa/public_modulus.rs +++ b/src/rsa/public_modulus.rs @@ -45,8 +45,9 @@ impl PublicModulus { // the public modulus to be exactly 2048 or 3072 bits, but we are more // flexible to be compatible with other commonly-used crypto libraries. assert!(min_bits >= MIN_BITS); - let bits_rounded_up = - bits::BitLength::from_byte_len(bits.as_usize_bytes_rounded_up()).unwrap(); // TODO: safe? + let bits_rounded_up = bits::BitLength::from_byte_len(bits.as_usize_bytes_rounded_up()) + .map_err(error::erase) + .unwrap(); // TODO: safe? if bits_rounded_up < min_bits { return Err(error::KeyRejected::too_small()); } diff --git a/src/rsa/verification.rs b/src/rsa/verification.rs index ac4da7728..0624f28af 100644 --- a/src/rsa/verification.rs +++ b/src/rsa/verification.rs @@ -201,7 +201,7 @@ pub(crate) fn verify_rsa_( cpu_features: cpu::Features, ) -> Result<(), error::Unspecified> { let max_bits: bits::BitLength = - bits::BitLength::from_byte_len(PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN)?; + bits::BitLength::from_byte_len(PUBLIC_KEY_PUBLIC_MODULUS_MAX_LEN).map_err(error::erase)?; // XXX: FIPS 186-4 seems to indicate that the minimum // exponent value is 2**16 + 1, but it isn't clear if this is just for diff --git a/src/tests/bits_tests.rs b/src/tests/bits_tests.rs index 0081f89eb..f088f1f7f 100644 --- a/src/tests/bits_tests.rs +++ b/src/tests/bits_tests.rs @@ -13,11 +13,8 @@ // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. use crate::{ + bits::{BitLength, FromByteLen as _}, polyfill::u64_from_usize, - { - bits::{BitLength, FromByteLen as _}, - error, - }, }; #[test] @@ -26,28 +23,40 @@ fn test_from_byte_len_overflow() { // Maximum valid input for BitLength. { - let bits = BitLength::::from_byte_len(USIZE_MAX_VALID_BYTES).unwrap(); - assert_eq!(bits.as_usize_bytes_rounded_up(), USIZE_MAX_VALID_BYTES); - assert_eq!(bits.as_bits(), usize::MAX & !0b111); + match BitLength::::from_byte_len(USIZE_MAX_VALID_BYTES) { + Ok(bits) => { + assert_eq!(bits.as_usize_bytes_rounded_up(), USIZE_MAX_VALID_BYTES); + assert_eq!(bits.as_bits(), usize::MAX & !0b111); + } + Err(_) => { + unreachable!() + } + } } // Minimum invalid usize input for BitLength. - assert_eq!( + assert!(matches!( BitLength::::from_byte_len(USIZE_MAX_VALID_BYTES + 1), - Err(error::Unspecified) - ); + Err(_) + )); // Minimum invalid usize input for BitLength on 64-bit targets. { - let bits = BitLength::::from_byte_len(USIZE_MAX_VALID_BYTES + 1); + let r = BitLength::::from_byte_len(USIZE_MAX_VALID_BYTES + 1); if cfg!(target_pointer_width = "64") { - assert_eq!(bits, Err(error::Unspecified)); + matches!(r, Err(_)); } else { - let bits = bits.unwrap(); - assert_eq!( - bits.as_bits(), - (u64_from_usize(USIZE_MAX_VALID_BYTES) + 1) * 8 - ); + match r { + Ok(bits) => { + assert_eq!( + bits.as_bits(), + (u64_from_usize(USIZE_MAX_VALID_BYTES) + 1) * 8 + ); + } + Err(_) => { + unreachable!() + } + } } } @@ -55,13 +64,19 @@ fn test_from_byte_len_overflow() { // Maximum valid u64 input for BitLength. { - let bits = BitLength::::from_byte_len(U64_MAX_VALID_BYTES).unwrap(); - assert_eq!(bits.as_bits(), u64::MAX & !0b111); + match BitLength::::from_byte_len(U64_MAX_VALID_BYTES) { + Ok(bits) => assert_eq!(bits.as_bits(), u64::MAX & !0b111), + Err(_) => { + unreachable!() + } + }; } // Minimum invalid usize input for BitLength on 64-bit targets. { - let bits = BitLength::::from_byte_len(U64_MAX_VALID_BYTES + 1); - assert_eq!(bits, Err(error::Unspecified)); + assert!(matches!( + BitLength::::from_byte_len(U64_MAX_VALID_BYTES + 1), + Err(_) + )); } }