diff --git a/src/arithmetic/bigint.rs b/src/arithmetic/bigint.rs index 5fab9d136b..37ec586756 100644 --- a/src/arithmetic/bigint.rs +++ b/src/arithmetic/bigint.rs @@ -166,17 +166,7 @@ where // r *= 2. fn elem_double(r: &mut Elem, m: &Modulus) { - prefixed_extern! { - fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t); - } - unsafe { - LIMBS_shl_mod( - r.limbs.as_mut_ptr(), - r.limbs.as_ptr(), - m.limbs().as_ptr(), - m.limbs().len(), - ); - } + limb::limbs_double_mod(&mut r.limbs, m.limbs()) } // TODO: This is currently unused, but we intend to eventually use this to @@ -286,21 +276,8 @@ impl One { // is correct because R**2 will still be a multiple of the latter as // `N0::LIMBS_USED` is either one or two. fn newRR(m: &Modulus) -> Self { - let r = m.limbs().len() * LIMB_BITS; - - // base = 2**r - n. let mut base = m.zero(); - limb::limbs_negative_odd(&mut base.limbs, m.limbs()); - - // Correct base to 2**(lg m) (mod m). - let lg_m = m.len_bits().as_usize_bits(); - let leading_zero_bits_in_m = r - lg_m; - if leading_zero_bits_in_m != 0 { - debug_assert!(leading_zero_bits_in_m < LIMB_BITS); - // `limbs_negative_odd` flipped all the leading zero bits to ones. - // Flip them back. - *base.limbs.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m; - } + m.oneR(&mut base.limbs); // Double `base` so that base == R == 2**r (mod m). For normal moduli // that have the high bit of the highest limb set, this requires one @@ -320,8 +297,6 @@ impl One { const LG_BASE: usize = 2; // doubling vs. squaring trade-off. const _LG_BASE_IS_POWER_OF_2: () = assert!(LG_BASE.is_power_of_two()); - let doublings = leading_zero_bits_in_m + LG_BASE; - // r is divisible by LIMB_BITS and LIMB_BITS is divisible by LG_BASE so // r is divisible by LG_BASE. // @@ -333,10 +308,11 @@ impl One { // the Hamming weight is 2. #[allow(clippy::assertions_on_constants)] const _LIMB_BITS_DIVISIBLE_BY_LG_BASE: () = assert!(LIMB_BITS % LG_BASE == 0); + let r = m.limbs().len() * LIMB_BITS; debug_assert_eq!(r % LIMB_BITS, 0); debug_assert_eq!(r % LG_BASE, 0); let exponent = NonZeroU64::new(u64_from_usize(r / LG_BASE)).unwrap(); - for _ in 0..doublings { + for _ in 0..LG_BASE { elem_double(&mut base, m) } let RR = elem_exp_vartime(base, exponent, m); @@ -409,11 +385,8 @@ pub(crate) fn elem_exp_vartime( pub fn elem_exp_consttime( base: Elem, exponent: &PrivateExponent, - m: &OwnedModulusWithOne, + m: &Modulus, ) -> Result, error::Unspecified> { - let oneRR = m.oneRR(); - let m = &m.modulus(); - use crate::{bssl, limb::Window}; const WINDOW_BITS: usize = 5; @@ -461,13 +434,7 @@ pub fn elem_exp_consttime( } // table[0] = base**0 (i.e. 1). - { - let acc = entry_mut(&mut table, 0, num_limbs); - // `table` was initialized to zero and hasn't changed. - debug_assert!(acc.iter().all(|&value| value == 0)); - acc[0] = 1; - limbs_mont_mul(acc, &oneRR.0.limbs, m.limbs(), m.n0(), m.cpu_features()); - } + m.oneR(entry_mut(&mut table, 0, num_limbs)); entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs); for i in 2..TABLE_ENTRIES { @@ -504,13 +471,10 @@ pub fn elem_exp_consttime( pub fn elem_exp_consttime( base: Elem, exponent: &PrivateExponent, - m: &OwnedModulusWithOne, + m: &Modulus, ) -> Result, error::Unspecified> { use crate::limb::LIMB_BYTES; - let oneRR = m.oneRR(); - let m = &m.modulus(); - // Pretty much all the math here requires CPU feature detection to have // been done. `cpu_features` isn't threaded through all the internal // functions, so just make it clear that it has been done at this point. @@ -661,11 +625,7 @@ pub fn elem_exp_consttime( // All entries in `table` will be Montgomery encoded. // acc = table[0] = base**0 (i.e. 1). - // `acc` was initialized to zero and hasn't changed. Change it to 1 and then Montgomery - // encode it. - debug_assert!(acc.iter().all(|&value| value == 0)); - acc[0] = 1; - limbs_mont_mul(acc, &oneRR.0.limbs, m_cached, n0, cpu_features); + m.oneR(acc); scatter(table, acc, 0, num_limbs); // acc = base**1 (i.e. base). @@ -836,7 +796,7 @@ mod tests { .expect("valid exponent") }; let base = into_encoded(base, &m_); - let actual_result = elem_exp_consttime(base, &e, &m_).unwrap(); + let actual_result = elem_exp_consttime(base, &e, &m).unwrap(); assert_elem_eq(&actual_result, &expected_result); Ok(()) diff --git a/src/arithmetic/bigint/modulus.rs b/src/arithmetic/bigint/modulus.rs index 4eb8e8b91b..807cb70c09 100644 --- a/src/arithmetic/bigint/modulus.rs +++ b/src/arithmetic/bigint/modulus.rs @@ -225,6 +225,36 @@ pub struct Modulus<'a, M> { } impl Modulus<'_, M> { + pub(super) fn oneR(&self, out: &mut [Limb]) { + assert_eq!(self.limbs.len(), out.len()); + + let r = self.limbs.len() * LIMB_BITS; + + // out = 2**r - m where m = self. + limb::limbs_negative_odd(out, self.limbs); + + let lg_m = self.len_bits().as_usize_bits(); + let leading_zero_bits_in_m = r - lg_m; + + // When m's length is a multiple of LIMB_BITS, which is the case we + // most want to optimize for, then we already have + // out == 2**r - m == 2**r (mod m). + if leading_zero_bits_in_m != 0 { + debug_assert!(leading_zero_bits_in_m < LIMB_BITS); + // Correct out to 2**(lg m) (mod m). `limbs_negative_odd` flipped + // all the leading zero bits to ones. Flip them back. + *out.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m; + + // Now we have out == 2**(lg m) (mod m). Keep doubling until we get + // to 2**r (mod m). + for _ in 0..leading_zero_bits_in_m { + limb::limbs_double_mod(out, self.limbs) + } + } + + // Now out == 2**r (mod m) == 1*R. + } + // TODO: XXX Avoid duplication with `Modulus`. pub(super) fn zero(&self) -> Elem { Elem { diff --git a/src/limb.rs b/src/limb.rs index 8dd53099ea..ee139d6262 100644 --- a/src/limb.rs +++ b/src/limb.rs @@ -350,6 +350,17 @@ pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) { unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) } } +// r *= 2 (mod m). +pub(crate) fn limbs_double_mod(r: &mut [Limb], m: &[Limb]) { + assert_eq!(r.len(), m.len()); + prefixed_extern! { + fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t); + } + unsafe { + LIMBS_shl_mod(r.as_mut_ptr(), r.as_ptr(), m.as_ptr(), m.len()); + } +} + // *r = -a, assuming a is odd. pub(crate) fn limbs_negative_odd(r: &mut [Limb], a: &[Limb]) { debug_assert_eq!(r.len(), a.len()); diff --git a/src/rsa/keypair.rs b/src/rsa/keypair.rs index 819c6678eb..f9a9e16eb4 100644 --- a/src/rsa/keypair.rs +++ b/src/rsa/keypair.rs @@ -468,7 +468,7 @@ fn elem_exp_consttime( // in the Smooth CRT-RSA paper. let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m); let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m); - bigint::elem_exp_consttime(c_mod_m, &p.exponent, &p.modulus) + bigint::elem_exp_consttime(c_mod_m, &p.exponent, m) } // Type-level representations of the different moduli used in RSA signing, in