Skip to content

Commit

Permalink
VIDPF restructure
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Dec 7, 2023
1 parent d891bd6 commit 8174f79
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 44 deletions.
14 changes: 8 additions & 6 deletions src/bin/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@ use std::{

use futures::try_join;
use mastic::{
collect, config, dpf,
collect, config,
rpc::{
AddFLPsRequest, AddKeysRequest, ApplyFLPResultsRequest, FinalSharesRequest,
GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest,
TreeCrawlRequest, TreeInitRequest, TreePruneRequest,
},
CollectorClient,
vidpf, CollectorClient,
};
use prio::{
field::{random_vector, Field64},
Expand All @@ -21,7 +21,7 @@ use rand::{distributions::Alphanumeric, thread_rng, Rng};
use rayon::prelude::*;
use tarpc::{client, context, serde_transport::tcp, tokio_serde::formats::Bincode};

type Key = dpf::DPFKey<Field64>;
type Key = vidpf::VIDPFKey<Field64>;
type Client = CollectorClient;

fn long_context() -> context::Context {
Expand Down Expand Up @@ -49,7 +49,9 @@ fn generate_keys(

let (keys_0, keys_1): (Vec<Key>, Vec<Key>) = rayon::iter::repeat(0)
.take(cfg.unique_buckets)
.map(|_| dpf::DPFKey::gen_from_str(&sample_string(cfg.data_bytes * 8), Field64::from(beta)))
.map(|_| {
vidpf::VIDPFKey::gen_from_str(&sample_string(cfg.data_bytes * 8), Field64::from(beta))
})
.unzip();

let (proofs_0, proofs_1): (Vec<Vec<Field64>>, Vec<Vec<Field64>>) = rayon::iter::repeat(0)
Expand Down Expand Up @@ -105,8 +107,8 @@ async fn add_keys(
cfg: &config::Config,
client_0: &Client,
client_1: &Client,
keys_0: &[dpf::DPFKey<Field64>],
keys_1: &[dpf::DPFKey<Field64>],
keys_0: &[vidpf::VIDPFKey<Field64>],
keys_1: &[vidpf::VIDPFKey<Field64>],
proofs_0: &[Vec<Field64>],
proofs_1: &[Vec<Field64>],
num_clients: usize,
Expand Down
10 changes: 5 additions & 5 deletions src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use rayon::prelude::*;
use rs_merkle::{Hasher, MerkleTree};
use serde::{Deserialize, Serialize};

use crate::{dpf, prg, xor_in_place, xor_vec, HASH_SIZE};
use crate::{prg, vidpf, xor_in_place, xor_vec, HASH_SIZE};

#[derive(Clone)]
pub struct HashAlg {}
Expand All @@ -24,7 +24,7 @@ impl Hasher for HashAlg {
struct TreeNode<T> {
path: Vec<bool>,
value: T,
key_states: Vec<dpf::EvalState>,
key_states: Vec<vidpf::EvalState>,
key_values: Vec<T>,
}

Expand All @@ -36,7 +36,7 @@ pub struct KeyCollection<T> {
server_id: i8,
verify_key: [u8; 16],
depth: usize,
pub keys: Vec<(bool, dpf::DPFKey<T>)>,
pub keys: Vec<(bool, vidpf::VIDPFKey<T>)>,
nonces: Vec<[u8; 16]>,
all_flp_proof_shares: Vec<Vec<T>>,
frontier: Vec<TreeNode<T>>,
Expand Down Expand Up @@ -83,7 +83,7 @@ where
}
}

pub fn add_key(&mut self, key: dpf::DPFKey<T>) {
pub fn add_key(&mut self, key: vidpf::VIDPFKey<T>) {
self.keys.push((true, key));
}

Expand Down Expand Up @@ -113,7 +113,7 @@ where
let mut bit_str = crate::bits_to_bitstring(parent.path.as_slice());
bit_str.push(if dir { '1' } else { '0' });

let (key_states, key_values): (Vec<dpf::EvalState>, Vec<T>) = self
let (key_states, key_values): (Vec<vidpf::EvalState>, Vec<T>) = self
.keys
.par_iter()
.enumerate()
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
pub mod collect;
pub mod config;
pub mod dpf;
pub mod prg;
pub mod rpc;
pub mod vidpf;

extern crate lazy_static;

Expand Down
4 changes: 2 additions & 2 deletions src/rpc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use prio::field::Field64;
use serde::{Deserialize, Serialize};

use crate::{collect, dpf, HASH_SIZE};
use crate::{collect, vidpf, HASH_SIZE};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResetRequest {
Expand All @@ -10,7 +10,7 @@ pub struct ResetRequest {

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AddKeysRequest {
pub keys: Vec<dpf::DPFKey<Field64>>,
pub keys: Vec<vidpf::VIDPFKey<Field64>>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
Expand Down
25 changes: 10 additions & 15 deletions src/dpf.rs → src/vidpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct CorWord<T> {
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DPFKey<T> {
pub struct VIDPFKey<T> {
pub key_idx: bool,
root_seed: prg::PrgSeed,
cor_words: Vec<CorWord<T>>,
Expand Down Expand Up @@ -80,7 +80,7 @@ impl<T> TupleExt<T> for (T, T) {

fn gen_cor_word<W>(
bit: bool,
value: W,
beta: W,
bits: &mut (bool, bool),
seeds: &mut (prg::PrgSeed, prg::PrgSeed),
) -> CorWord<W>
Expand Down Expand Up @@ -121,7 +121,7 @@ where
}

let converted = seeds.map(|s| s.convert());
cw.word = value;
cw.word = beta;
cw.word.sub_assign(converted.0.word);
cw.word.add_assign(converted.1.word);

Expand All @@ -136,13 +136,11 @@ where
}

/// All-prefix DPF implementation.
impl<T> DPFKey<T>
impl<T> VIDPFKey<T>
where
T: prg::FromRng + Clone + prio::field::FieldElement + std::fmt::Debug,
{
pub fn gen(alpha_bits: &[bool], values: &[T]) -> (DPFKey<T>, DPFKey<T>) {
debug_assert!(alpha_bits.len() == values.len() + 1);

pub fn gen(alpha_bits: &[bool], beta: T) -> (VIDPFKey<T>, VIDPFKey<T>) {
let root_seeds = (prg::PrgSeed::random(), prg::PrgSeed::random());
let root_bits = (false, true);

Expand All @@ -153,10 +151,9 @@ where
let mut cor_words: Vec<CorWord<T>> = Vec::new();
let mut cs: Vec<[u8; HASH_SIZE]> = Vec::new();
let mut bit_str: String = "".to_string();
for i in 0..alpha_bits.len() {
let bit = alpha_bits[i];
for &bit in alpha_bits {
bit_str.push_str(if bit { "1" } else { "0" });
let cw = gen_cor_word::<T>(bit, values[i], &mut bits, &mut seeds);
let cw = gen_cor_word::<T>(bit, beta, &mut bits, &mut seeds);
cor_words.push(cw);

let pi_0 = {
Expand All @@ -179,13 +176,13 @@ where
}

(
DPFKey::<T> {
VIDPFKey::<T> {
key_idx: false,
root_seed: root_seeds.0,
cor_words: cor_words.clone(),
cs: cs.clone(),
},
DPFKey::<T> {
VIDPFKey::<T> {
key_idx: true,
root_seed: root_seeds.1,
cor_words,
Expand Down Expand Up @@ -291,9 +288,7 @@ where

pub fn gen_from_str(s: &str, beta: T) -> (Self, Self) {
let bits = crate::string_to_bits(s);
let values = vec![beta; bits.len()];

DPFKey::gen(&bits, &values)
VIDPFKey::gen(&bits, beta)
}

pub fn domain_size(&self) -> usize {
Expand Down
4 changes: 2 additions & 2 deletions tests/collect_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ fn collect_test_eval_groups() {
let mut col_1 = KeyCollection::new(1, &seed, strlen, verify_key);

for cstr in &client_strings {
let (keys_0, keys_1) = dpf::DPFKey::<Field64>::gen_from_str(&cstr, Field64::one());
let (keys_0, keys_1) = vidpf::VIDPFKey::<Field64>::gen_from_str(&cstr, Field64::one());
col_0.add_key(keys_0);
col_1.add_key(keys_1);
}
Expand Down Expand Up @@ -101,7 +101,7 @@ fn collect_test_eval_full_groups() {
let mut keys = vec![];
println!("Starting to generate keys");
for s in &client_strings {
keys.push(dpf::DPFKey::<Field64>::gen_from_str(&s, Field64::one()));
keys.push(vidpf::VIDPFKey::<Field64>::gen_from_str(&s, Field64::one()));
}
println!("Done generating keys");

Expand Down
18 changes: 5 additions & 13 deletions tests/dpf_test.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
use std::ops::Add;

use blake3::hash;
use mastic::{dpf::*, *};
use mastic::{vidpf::*, *};
use prio::field::Field64;

#[test]
fn dpf_complete() {
let num_bits = 5;
let alpha = u32_to_bits(num_bits, 21);
let betas = vec![
Field64::from(7u64),
Field64::from(17u64),
Field64::from(2u64),
Field64::from(0u64),
Field64::from(32u64),
];
let (key_0, key_1) = DPFKey::gen(&alpha, &betas);
let beta = Field64::from(7u64);
let (key_0, key_1) = VIDPFKey::gen(&alpha, beta);

let mut pi_0: [u8; HASH_SIZE] = hash(b"0").as_bytes()[0..HASH_SIZE].try_into().unwrap();
let mut pi_1: [u8; HASH_SIZE] = pi_0.clone();
Expand All @@ -32,11 +26,9 @@ fn dpf_complete() {
println!("[{:?}] Tmp {:?} = {:?}", alpha_eval, j, tmp);
if alpha[0..j - 1] == alpha_eval[0..j - 1] {
assert_eq!(
betas[j - 2],
tmp,
beta, tmp,
"[Level {:?}] Value incorrect at {:?}",
j,
alpha_eval
j, alpha_eval
);
} else {
assert_eq!(Field64::from(0), tmp);
Expand Down

0 comments on commit 8174f79

Please sign in to comment.