Skip to content

Commit

Permalink
Simplify DPF
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Oct 25, 2023
1 parent b777d3c commit db1cde9
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 176 deletions.
10 changes: 5 additions & 5 deletions src/bin/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use mastic::{
};

use futures::try_join;
use prio::field::Field64;
use prio::field::{Field64, FieldElement};
use rand::{distributions::Alphanumeric, Rng};
use rayon::prelude::*;
use std::{
Expand All @@ -17,7 +17,7 @@ use std::{
};
use tarpc::{client, context, serde_transport::tcp, tokio_serde::formats::Bincode};

type Key = dpf::DPFKey<Field64, Field64>;
type Key = dpf::DPFKey<Field64>;
type Client = CollectorClient;

fn long_context() -> context::Context {
Expand All @@ -39,7 +39,7 @@ fn sample_string(len: usize) -> String {
fn generate_keys(cfg: &config::Config) -> (Vec<Key>, Vec<Key>) {
let (keys_0, keys_1): (Vec<Key>, Vec<Key>) = rayon::iter::repeat(0)
.take(cfg.unique_buckets)
.map(|_| dpf::DPFKey::gen_from_str(&sample_string(cfg.data_bytes * 8)))
.map(|_| dpf::DPFKey::gen_from_str(&sample_string(cfg.data_bytes * 8), Field64::one()))
.unzip();

let encoded: Vec<u8> = bincode::serialize(&keys_0[0]).unwrap();
Expand Down Expand Up @@ -70,8 +70,8 @@ async fn add_keys(
cfg: &config::Config,
client_0: &Client,
client_1: &Client,
keys_0: &[dpf::DPFKey<Field64, Field64>],
keys_1: &[dpf::DPFKey<Field64, Field64>],
keys_0: &[dpf::DPFKey<Field64>],
keys_1: &[dpf::DPFKey<Field64>],
num_clients: usize,
malicious_percentage: f32,
) -> io::Result<()> {
Expand Down
6 changes: 3 additions & 3 deletions src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ unsafe impl<T> Sync for TreeNode<T> {}
pub struct KeyCollection<T> {
server_id: i8,
depth: usize,
pub keys: Vec<(bool, dpf::DPFKey<T, T>)>,
pub keys: Vec<(bool, dpf::DPFKey<T>)>,
honest_clients: Vec<bool>,
frontier: Vec<TreeNode<T>>,
prev_frontier: Vec<TreeNode<T>>,
Expand Down Expand Up @@ -71,7 +71,7 @@ where
}
}

pub fn add_key(&mut self, key: dpf::DPFKey<T, T>) {
pub fn add_key(&mut self, key: dpf::DPFKey<T>) {
self.keys.push((true, key));
self.honest_clients.push(true);
}
Expand Down Expand Up @@ -133,7 +133,7 @@ where
.keys
.par_iter()
.enumerate()
.map(|(i, key)| key.1.eval_bit_last(&parent.key_states[i], dir, &bit_str))
.map(|(i, key)| key.1.eval_bit(&parent.key_states[i], dir, &bit_str))
.unzip();

let mut child_val = T::zero();
Expand Down
97 changes: 12 additions & 85 deletions src/dpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@ struct CorWord<T> {
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DPFKey<T, U> {
pub struct DPFKey<T> {
pub key_idx: bool,
root_seed: prg::PrgSeed,
cor_words: Vec<CorWord<T>>,
cor_word_last: CorWord<U>,
pub cs: Vec<Vec<u8>>,
}

Expand Down Expand Up @@ -139,12 +138,11 @@ where
}

/// All-prefix DPF implementation.
impl<T, U> DPFKey<T, U>
impl<T> DPFKey<T>
where
T: prg::FromRng + Clone + prio::field::FieldElement + std::fmt::Debug,
U: prg::FromRng + Clone + prio::field::FieldElement + std::fmt::Debug,
{
pub fn gen(alpha_bits: &[bool], values: &[T], value_last: &U) -> (DPFKey<T, U>, DPFKey<T, U>) {
pub fn gen(alpha_bits: &[bool], values: &[T]) -> (DPFKey<T>, DPFKey<T>) {
debug_assert!(alpha_bits.len() == values.len() + 1);

let root_seeds = (prg::PrgSeed::random(), prg::PrgSeed::random());
Expand All @@ -157,7 +155,7 @@ where
let mut cor_words: Vec<CorWord<T>> = Vec::new();
let mut cs: Vec<Vec<u8>> = Vec::new();
let mut bit_str: String = "".to_string();
for i in 0..(alpha_bits.len() - 1) {
for i in 0..alpha_bits.len() {
let bit = alpha_bits[i];
bit_str.push_str(if bit { "1" } else { "0" });
let cw = gen_cor_word::<T>(bit, values[i], &mut bits, &mut seeds);
Expand All @@ -176,35 +174,17 @@ where
cs.push(crate::xor_vec(&pi_0, &pi_1));
}

let bit = alpha_bits[values.len()];
bit_str.push_str(if bit { "1" } else { "0" });
let last_cw = gen_cor_word::<U>(bit, *value_last, &mut bits, &mut seeds);

let pi_0 = {
hasher.update(&bit_str);
hasher.update(seeds.0.key);
hasher.finalize_reset().to_vec()
};
let pi_1 = {
hasher.update(&bit_str);
hasher.update(seeds.1.key);
hasher.finalize_reset().to_vec()
};
cs.push(crate::xor_vec(&pi_0, &pi_1));

(
DPFKey::<T, U> {
DPFKey::<T> {
key_idx: false,
root_seed: root_seeds.0,
cor_words: cor_words.clone(),
cor_word_last: last_cw.clone(),
cs: cs.clone(),
},
DPFKey::<T, U> {
DPFKey::<T> {
key_idx: true,
root_seed: root_seeds.1,
cor_words,
cor_word_last: last_cw,
cs,
},
)
Expand Down Expand Up @@ -264,60 +244,6 @@ where
)
}

pub fn eval_bit_last(&self, state: &EvalState, dir: bool, bit_str: &String) -> (EvalState, U) {
let tau = state.seed.expand_dir(!dir, dir);
let mut seed = tau.seeds.get(dir).clone();
let mut new_bit = *tau.bits.get(dir);

if state.bit {
seed = &seed ^ &self.cor_word_last.seed;
new_bit ^= self.cor_word_last.bits.get(dir);
}

let converted = seed.convert::<U>();
let new_seed = converted.seed;

let mut word = converted.word;
if new_bit {
word.add_assign(self.cor_word_last.word);
}

if self.key_idx {
word = word.neg();
}

// Compute proofs
let mut hasher = Sha256::new();
let pi_prime = {
hasher.update(bit_str);
hasher.update(new_seed.key);
hasher.finalize_reset().to_vec()
};
let h2 = {
let h: [u8; 32] = if !new_bit {
// H(pi ^ pi_prime)
xor_vec(&state.proof, &pi_prime).try_into().unwrap()
} else {
// H(pi ^ pi_prime ^ cs)
xor_three_vecs(&state.proof, &pi_prime, &self.cs[state.level])
.try_into()
.unwrap()
};
hasher.update(h);
&hasher.finalize_reset()
};

(
EvalState {
level: state.level + 1,
seed,
bit: new_bit,
proof: xor_vec(h2, &state.proof),
},
word,
)
}

pub fn eval_init(&self) -> EvalState {
EvalState {
level: 0,
Expand All @@ -327,7 +253,7 @@ where
}
}

pub fn eval(&self, idx: &[bool], pi: &mut Vec<u8>) -> (Vec<T>, U) {
pub fn eval(&self, idx: &[bool], pi: &mut Vec<u8>) -> (Vec<T>, T) {
debug_assert!(idx.len() <= self.domain_size());
debug_assert!(!idx.is_empty());
let mut out = vec![];
Expand All @@ -344,16 +270,17 @@ where
state = state_new;
}

let (_, last) = self.eval_bit_last(&state, *idx.last().unwrap(), &bit_str);
let (_, last) = self.eval_bit(&state, *idx.last().unwrap(), &bit_str);
*pi = state.proof;

(out, last)
}

pub fn gen_from_str(s: &str) -> (Self, Self) {
pub fn gen_from_str(s: &str, beta: T) -> (Self, Self) {
let bits = crate::string_to_bits(s);
let values = vec![T::one(); bits.len() - 1];
DPFKey::gen(&bits, &values, &U::one())
let values = vec![beta; bits.len()];

DPFKey::gen(&bits, &values)
}

pub fn domain_size(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion src/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct ResetRequest {}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AddKeysRequest {
pub keys: Vec<dpf::DPFKey<Field64, Field64>>,
pub keys: Vec<dpf::DPFKey<Field64>>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
Expand Down
Loading

0 comments on commit db1cde9

Please sign in to comment.