Skip to content

Commit

Permalink
bigint: Calculate 1*R mod m without multiplication by 1*RR.
Browse files Browse the repository at this point in the history
Save two private-modulus Montgomery multiplications per RSA exponentiation
at the cost of approximately two modulus-wide XORs.

The new new `oneR()` is extracted from the Montgomery RR setup.

Remove the use of `One<RR>` in `elem_exp_consttime`.
  • Loading branch information
briansmith committed Nov 12, 2023
1 parent 98a78c7 commit b6099b4
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 56 deletions.
68 changes: 13 additions & 55 deletions src/arithmetic/bigint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,7 @@ where

// r *= 2.
fn elem_double<M, AF>(r: &mut Elem<M, AF>, m: &Modulus<M>) {
prefixed_extern! {
fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
}
unsafe {
LIMBS_shl_mod(
r.limbs.as_mut_ptr(),
r.limbs.as_ptr(),
m.limbs().as_ptr(),
m.limbs().len(),
);
}
limb::limbs_double_mod(&mut r.limbs, m.limbs())
}

// TODO: This is currently unused, but we intend to eventually use this to
Expand Down Expand Up @@ -288,29 +278,15 @@ impl<M> One<M, RR> {
fn newRR(m: &Modulus<M>) -> Self {
let r = m.limbs().len() * LIMB_BITS;

// base = 2**r - n.
// base = 2**r (mod n) == R (mod n).
let mut base = m.zero();
limb::limbs_negative_odd(&mut base.limbs, m.limbs());

// Correct base to 2**(lg m) (mod m).
let lg_m = m.len_bits().as_usize_bits();
let leading_zero_bits_in_m = r - lg_m;
if leading_zero_bits_in_m != 0 {
debug_assert!(leading_zero_bits_in_m < LIMB_BITS);
// `limbs_negative_odd` flipped all the leading zero bits to ones.
// Flip them back.
*base.limbs.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m;
}
m.oneR(&mut base.limbs);

// Double `base` so that base == R == 2**r (mod m). For normal moduli
// that have the high bit of the highest limb set, this requires one
// doubling. Unusual moduli require more doublings but we are less
// concerned about the performance of those.
// Compute RR = R**2 == base**r (mod n).
//
// Then double `base` again so that base == 2*R (mod n), i.e. `2` in
// Montgomery form (`elem_exp_vartime()` requires the base to be in
// Montgomery form). Then compute
// RR = R**2 == base**r == R**r == (2**r)**r (mod n).
// Double `base` so that base == 2*R (mod n), i.e. `2` in Montgomery
// form (`elem_exp_vartime()` requires the base to be in Montgomery
// form).
//
// Take advantage of the fact that `elem_double` is faster than
// `elem_squared` by replacing some of the early squarings with
Expand All @@ -320,8 +296,6 @@ impl<M> One<M, RR> {
const LG_BASE: usize = 2; // doubling vs. squaring trade-off.
const _LG_BASE_IS_POWER_OF_2: () = assert!(LG_BASE.is_power_of_two());

let doublings = leading_zero_bits_in_m + LG_BASE;

// r is divisible by LIMB_BITS and LIMB_BITS is divisible by LG_BASE so
// r is divisible by LG_BASE.
//
Expand All @@ -336,7 +310,7 @@ impl<M> One<M, RR> {
debug_assert_eq!(r % LIMB_BITS, 0);
debug_assert_eq!(r % LG_BASE, 0);
let exponent = NonZeroU64::new(u64_from_usize(r / LG_BASE)).unwrap();
for _ in 0..doublings {
for _ in 0..LG_BASE {
elem_double(&mut base, m)
}
let RR = elem_exp_vartime(base, exponent, m);
Expand Down Expand Up @@ -409,11 +383,8 @@ pub(crate) fn elem_exp_vartime<M>(
pub fn elem_exp_consttime<M>(
base: Elem<M, R>,
exponent: &PrivateExponent,
m: &OwnedModulusWithOne<M>,
m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
let oneRR = m.oneRR();
let m = &m.modulus();

use crate::{bssl, limb::Window};

const WINDOW_BITS: usize = 5;
Expand Down Expand Up @@ -461,13 +432,7 @@ pub fn elem_exp_consttime<M>(
}

// table[0] = base**0 (i.e. 1).
{
let acc = entry_mut(&mut table, 0, num_limbs);
// `table` was initialized to zero and hasn't changed.
debug_assert!(acc.iter().all(|&value| value == 0));
acc[0] = 1;
limbs_mont_mul(acc, &oneRR.0.limbs, m.limbs(), m.n0(), m.cpu_features());
}
m.oneR(entry_mut(&mut table, 0, num_limbs));

entry_mut(&mut table, 1, num_limbs).copy_from_slice(&base.limbs);
for i in 2..TABLE_ENTRIES {
Expand Down Expand Up @@ -504,13 +469,10 @@ pub fn elem_exp_consttime<M>(
pub fn elem_exp_consttime<M>(
base: Elem<M, R>,
exponent: &PrivateExponent,
m: &OwnedModulusWithOne<M>,
m: &Modulus<M>,
) -> Result<Elem<M, Unencoded>, error::Unspecified> {
use crate::limb::LIMB_BYTES;

let oneRR = m.oneRR();
let m = &m.modulus();

// Pretty much all the math here requires CPU feature detection to have
// been done. `cpu_features` isn't threaded through all the internal
// functions, so just make it clear that it has been done at this point.
Expand Down Expand Up @@ -661,11 +623,7 @@ pub fn elem_exp_consttime<M>(
// All entries in `table` will be Montgomery encoded.

// acc = table[0] = base**0 (i.e. 1).
// `acc` was initialized to zero and hasn't changed. Change it to 1 and then Montgomery
// encode it.
debug_assert!(acc.iter().all(|&value| value == 0));
acc[0] = 1;
limbs_mont_mul(acc, &oneRR.0.limbs, m_cached, n0, cpu_features);
m.oneR(acc);
scatter(table, acc, 0, num_limbs);

// acc = base**1 (i.e. base).
Expand Down Expand Up @@ -836,7 +794,7 @@ mod tests {
.expect("valid exponent")
};
let base = into_encoded(base, &m_);
let actual_result = elem_exp_consttime(base, &e, &m_).unwrap();
let actual_result = elem_exp_consttime(base, &e, &m).unwrap();
assert_elem_eq(&actual_result, &expected_result);

Ok(())
Expand Down
30 changes: 30 additions & 0 deletions src/arithmetic/bigint/modulus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,36 @@ pub struct Modulus<'a, M> {
}

impl<M> Modulus<'_, M> {
pub(super) fn oneR(&self, out: &mut [Limb]) {
assert_eq!(self.limbs.len(), out.len());

let r = self.limbs.len() * LIMB_BITS;

// out = 2**r - m where m = self.
limb::limbs_negative_odd(out, self.limbs);

let lg_m = self.len_bits().as_usize_bits();
let leading_zero_bits_in_m = r - lg_m;

// When m's length is a multiple of LIMB_BITS, which is the case we
// most want to optimize for, then we already have
// out == 2**r - m == 2**r (mod m).
if leading_zero_bits_in_m != 0 {
debug_assert!(leading_zero_bits_in_m < LIMB_BITS);
// Correct out to 2**(lg m) (mod m). `limbs_negative_odd` flipped
// all the leading zero bits to ones. Flip them back.
*out.last_mut().unwrap() &= (!0) >> leading_zero_bits_in_m;

// Now we have out == 2**(lg m) (mod m). Keep doubling until we get
// to 2**r (mod m).
for _ in 0..leading_zero_bits_in_m {
limb::limbs_double_mod(out, self.limbs)
}
}

// Now out == 2**r (mod m) == 1*R.
}

// TODO: XXX Avoid duplication with `Modulus`.
pub(super) fn zero<E>(&self) -> Elem<M, E> {
Elem {
Expand Down
11 changes: 11 additions & 0 deletions src/limb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,17 @@ pub(crate) fn limbs_add_assign_mod(a: &mut [Limb], b: &[Limb], m: &[Limb]) {
unsafe { LIMBS_add_mod(a.as_mut_ptr(), a.as_ptr(), b.as_ptr(), m.as_ptr(), m.len()) }
}

// r *= 2 (mod m).
pub(crate) fn limbs_double_mod(r: &mut [Limb], m: &[Limb]) {
assert_eq!(r.len(), m.len());
prefixed_extern! {
fn LIMBS_shl_mod(r: *mut Limb, a: *const Limb, m: *const Limb, num_limbs: c::size_t);
}
unsafe {
LIMBS_shl_mod(r.as_mut_ptr(), r.as_ptr(), m.as_ptr(), m.len());
}
}

// *r = -a, assuming a is odd.
pub(crate) fn limbs_negative_odd(r: &mut [Limb], a: &[Limb]) {
debug_assert_eq!(r.len(), a.len());
Expand Down
2 changes: 1 addition & 1 deletion src/rsa/keypair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ fn elem_exp_consttime<M>(
// in the Smooth CRT-RSA paper.
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m);
let c_mod_m = bigint::elem_mul(p.modulus.oneRR().as_ref(), c_mod_m, m);
bigint::elem_exp_consttime(c_mod_m, &p.exponent, &p.modulus)
bigint::elem_exp_consttime(c_mod_m, &p.exponent, m)
}

// Type-level representations of the different moduli used in RSA signing, in
Expand Down

0 comments on commit b6099b4

Please sign in to comment.