Add basic FLPs.
TODO: joint randomness, nonce, and some optimizations
jimouris committed Oct 30, 2023
1 parent db1cde9 commit e6a3674
Showing 12 changed files with 326 additions and 129 deletions.
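Note: the FLPs being added are prio's "fully linear proofs" (the `flp` module of the prio crate), which let a client prove that a secret-shared measurement is well formed. As a minimal sketch of the API this commit builds on (single party, no secret sharing; the diff below additively shares input, proof, and verifier between the two servers):

use prio::{
    field::{random_vector, Field64},
    flp::{types::Count, Type},
};

fn main() {
    let count = Count::new();

    // Client: encode a one-bit measurement and prove that it is valid.
    let beta = 1u64;
    let input: Vec<Field64> = count.encode_measurement(&beta).unwrap();
    let prove_rand = random_vector(count.prove_rand_len()).unwrap();
    let proof = count.prove(&input, &prove_rand, &[]).unwrap();

    // Verifier: query (input, proof) with fresh query randomness, then decide.
    let query_rand = random_vector(count.query_rand_len()).unwrap();
    let verifier = count.query(&input, &proof, &query_rand, &[], 1).unwrap();
    assert!(count.decide(&verifier).unwrap());
}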
1 change: 1 addition & 0 deletions Cargo.toml
@@ -14,6 +14,7 @@ bincode = "1.3.3"
 clap = "2.0"
 ctr = "0.9.2"
 futures = "0.3.28"
+getrandom = { version = "0.2.10", features = ["std"] }
 itertools = "0.10.5"
 lazy_static = "1.4"
 num = "0.4.0"
8 changes: 8 additions & 0 deletions rustfmt.toml
@@ -0,0 +1,8 @@
+# Get help on options with `rustfmt --help=config`
+# Please keep these in alphabetical order.
+edition = "2021"
+group_imports = "StdExternalCrate"
+imports_granularity = "Crate"
+merge_derives = false
+use_field_init_shorthand = true
+version = "Two"
149 changes: 124 additions & 25 deletions src/bin/leader.rs
@@ -1,20 +1,24 @@
+use std::{
+    io,
+    time::{Duration, Instant, SystemTime},
+};
+
+use futures::try_join;
 use mastic::{
     collect, config, dpf,
     rpc::{
-        AddKeysRequest, FinalSharesRequest, ResetRequest, TreeCrawlLastRequest, TreeCrawlRequest,
-        TreeInitRequest, TreePruneRequest,
+        AddFLPsRequest, AddKeysRequest, ApplyFLPResultsRequest, FinalSharesRequest, ResetRequest,
+        RunFlpQueriesRequest, TreeCrawlLastRequest, TreeCrawlRequest, TreeInitRequest,
+        TreePruneRequest,
     },
     CollectorClient,
 };
-
-use futures::try_join;
-use prio::field::{Field64, FieldElement};
-use rand::{distributions::Alphanumeric, Rng};
-use rayon::prelude::*;
-use std::{
-    io,
-    time::{Duration, Instant, SystemTime},
+use prio::{
+    field::{random_vector, Field64},
+    flp::{types::Count, Type},
 };
+use rand::{distributions::Alphanumeric, thread_rng, Rng};
+use rayon::prelude::*;
 use tarpc::{client, context, serde_transport::tcp, tokio_serde::formats::Bincode};
 
 type Key = dpf::DPFKey<Field64>;
@@ -36,20 +40,51 @@ fn sample_string(len: usize) -> String {
         .collect()
 }
 
-fn generate_keys(cfg: &config::Config) -> (Vec<Key>, Vec<Key>) {
+fn generate_keys(
+    cfg: &config::Config,
+) -> ((Vec<Key>, Vec<Key>), (Vec<Vec<Field64>>, Vec<Vec<Field64>>)) {
+    let beta = 1u64;
+    let count = Count::new();
+    let input_beta: Vec<Field64> = count.encode_measurement(&beta).unwrap();
+
     let (keys_0, keys_1): (Vec<Key>, Vec<Key>) = rayon::iter::repeat(0)
         .take(cfg.unique_buckets)
-        .map(|_| dpf::DPFKey::gen_from_str(&sample_string(cfg.data_bytes * 8), Field64::one()))
+        .map(|_| dpf::DPFKey::gen_from_str(&sample_string(cfg.data_bytes * 8), Field64::from(beta)))
         .unzip();
 
+    let (proofs_0, proofs_1): (Vec<Vec<Field64>>, Vec<Vec<Field64>>) = rayon::iter::repeat(0)
+        .take(cfg.unique_buckets)
+        .map(|_| {
+            let prove_rand = random_vector(count.prove_rand_len()).unwrap();
+            let proof = count.prove(&input_beta, &prove_rand, &[]).unwrap();
+
+            let proof_0 = proof
+                .iter()
+                .map(|_| Field64::from(rand::thread_rng().gen::<u64>()))
+                .collect::<Vec<_>>();
+            let proof_1 = proof
+                .par_iter()
+                .zip(proof_0.par_iter())
+                .map(|(p_0, p_1)| p_0 - p_1)
+                .collect::<Vec<_>>();
+            (proof_0, proof_1)
+        })
+        .unzip();
+
     let encoded: Vec<u8> = bincode::serialize(&keys_0[0]).unwrap();
     println!("Key size: {:?} bytes", encoded.len());
 
-    (keys_0, keys_1)
+    ((keys_0, keys_1), (proofs_0, proofs_1))
 }
 
-async fn reset_servers(client_0: &Client, client_1: &Client) -> io::Result<()> {
-    let req = ResetRequest {};
+async fn reset_servers(
+    client_0: &Client,
+    client_1: &Client,
+    verify_key: &[u8; 16],
+) -> io::Result<()> {
+    let req = ResetRequest {
+        verify_key: *verify_key,
+    };
     let response_0 = client_0.reset(long_context(), req.clone());
     let response_1 = client_1.reset(long_context(), req);
     try_join!(response_0, response_1).unwrap();
@@ -72,6 +107,8 @@ async fn add_keys(
     client_1: &Client,
     keys_0: &[dpf::DPFKey<Field64>],
     keys_1: &[dpf::DPFKey<Field64>],
+    proofs_0: &[Vec<Field64>],
+    proofs_1: &[Vec<Field64>],
     num_clients: usize,
     malicious_percentage: f32,
 ) -> io::Result<()> {
@@ -81,23 +118,79 @@ async fn add_keys(
 
     let mut add_keys_0 = Vec::with_capacity(num_clients);
     let mut add_keys_1 = Vec::with_capacity(num_clients);
+
+    let mut flp_proof_shares_0 = Vec::with_capacity(num_clients);
+    let mut flp_proof_shares_1 = Vec::with_capacity(num_clients);
     for r in 0..num_clients {
         let idx_1 = zipf.sample(&mut rng) - 1;
         let mut idx_2 = idx_1;
-        if rand::thread_rng().gen_range(0.0..1.0) < malicious_percentage {
-            idx_2 += 1;
+        let mut idx_3 = idx_1;
+        if rng.gen_range(0.0..1.0) < malicious_percentage {
+            if rng.gen() {
+                // Malicious key.
+                idx_2 += 1;
+            } else {
+                // Malicious FLP.
+                idx_3 += 1;
+            }
             println!("Malicious {}", r);
         }
 
         add_keys_0.push(keys_0[idx_1].clone());
         add_keys_1.push(keys_1[idx_2 % cfg.unique_buckets].clone());
+
+        flp_proof_shares_0.push(proofs_0[idx_1].clone());
+        flp_proof_shares_1.push(proofs_1[idx_3 % cfg.unique_buckets].clone());
     }
 
     let response_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
     let response_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
     try_join!(response_0, response_1).unwrap();
 
+    let response_0 = client_0.add_all_flp_proof_shares(
+        long_context(),
+        AddFLPsRequest {
+            flp_proof_shares: flp_proof_shares_0,
+        },
+    );
+    let response_1 = client_1.add_all_flp_proof_shares(
+        long_context(),
+        AddFLPsRequest {
+            flp_proof_shares: flp_proof_shares_1,
+        },
+    );
+    try_join!(response_0, response_1).unwrap();
+
     Ok(())
 }
+
+async fn run_flp_queries(client_0: &Client, client_1: &Client) -> io::Result<()> {
+    let req = RunFlpQueriesRequest {};
+    let response_0 = client_0.run_flp_queries(long_context(), req.clone());
+    let response_1 = client_1.run_flp_queries(long_context(), req);
+    let (flp_verifier_shares_0, flp_verifier_shares_1) = try_join!(response_0, response_1).unwrap();
+
+    assert_eq!(flp_verifier_shares_0.len(), flp_verifier_shares_1.len());
+
+    let count = Count::new();
+    let keep = flp_verifier_shares_0
+        .par_iter()
+        .zip(flp_verifier_shares_1.par_iter())
+        .map(|(flp_verifier_share_0, flp_verifier_share_1)| {
+            let flp_verifier = flp_verifier_share_0
+                .par_iter()
+                .zip(flp_verifier_share_1.par_iter())
+                .map(|(&v1, &v2)| v1 + v2)
+                .collect::<Vec<_>>();
+
+            count.decide(&flp_verifier).unwrap()
+        })
+        .collect::<Vec<_>>();
+
+    // Tree prune
+    let req = ApplyFLPResultsRequest { keep };
+    let response_0 = client_0.apply_flp_results(long_context(), req.clone());
+    let response_1 = client_1.apply_flp_results(long_context(), req);
+    try_join!(response_0, response_1).unwrap();
+
+    Ok(())
+}

@@ -128,7 +221,6 @@ async fn run_level(
     assert_eq!(cnt_values_0.len(), cnt_values_1.len());
     keep =
         collect::KeyCollection::<Field64>::keep_values(threshold, &cnt_values_0, &cnt_values_1);
-    println!("mt_root_0.len() {}", mt_root_0.len());
     if mt_root_0.is_empty() {
         break;
     }
@@ -236,7 +328,7 @@ async fn main() -> io::Result<()> {
 
     let start = Instant::now();
     println!("Generating keys...");
-    let (keys_0, keys_1) = generate_keys(&cfg);
+    let ((keys_0, keys_1), (proofs_0, proofs_1)) = generate_keys(&cfg);
     let delta = start.elapsed().as_secs_f64();
     println!(
         "Generated {:?} keys in {:?} seconds ({:?} sec/key)",
@@ -245,7 +337,10 @@ async fn main() -> io::Result<()> {
         delta / (keys_0.len() as f64)
     );
 
-    reset_servers(&client_0, &client_1).await?;
+    let mut verify_key = [0; 16];
+    thread_rng().fill(&mut verify_key);
+
+    reset_servers(&client_0, &client_1, &verify_key).await?;
 
     let mut left_to_go = num_clients;
     let reqs_in_flight = 1000;
@@ -258,7 +353,8 @@ async fn main() -> io::Result<()> {
 
         if this_batch > 0 {
             responses.push(add_keys(
-                &cfg, &client_0, &client_1, &keys_0, &keys_1, this_batch, malicious,
+                &cfg, &client_0, &client_1, &keys_0, &keys_1, &proofs_0, &proofs_1, this_batch,
+                malicious,
             ));
         }
     }
@@ -272,12 +368,15 @@
 
     let start = Instant::now();
     let bit_len = cfg.data_bytes * 8; // bits
-    for _level in 0..bit_len - 1 {
+    for level in 0..bit_len - 1 {
         let start_level = Instant::now();
+        if level == 0 {
+            run_flp_queries(&client_0, &client_1).await?;
+        }
         run_level(&cfg, &client_0, &client_1, num_clients).await?;
         println!(
             "Time for level {}: {:?}",
-            _level,
+            level,
             start_level.elapsed().as_secs_f64()
        );
    }
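The proof handling above is plain additive secret sharing: generate_keys splits each proof into a random share and its difference from the real proof, add_keys ships one share to each server, and run_flp_queries sums the servers' verifier shares before calling decide. Because the FLP query is linear in the input and proof, querying shares and summing yields the same verifier as querying in the clear. A self-contained sketch of that recombination (same prio API as above; the share helper is illustrative, not from the repo, and the real servers presumably query DPF-derived input shares inside coll.run_flp_queries()):

use prio::{
    field::{random_vector, Field64},
    flp::{types::Count, Type},
};

// Illustrative helper: additively share a vector into two random-looking shares.
fn share(x: &[Field64]) -> (Vec<Field64>, Vec<Field64>) {
    let s0 = random_vector(x.len()).unwrap();
    let s1: Vec<Field64> = x.iter().zip(s0.iter()).map(|(v, r)| *v - *r).collect();
    (s0, s1)
}

fn main() {
    let count = Count::new();
    let input = count.encode_measurement(&1u64).unwrap();
    let prove_rand = random_vector(count.prove_rand_len()).unwrap();
    let proof = count.prove(&input, &prove_rand, &[]).unwrap();

    let (input_0, input_1) = share(&input);
    let (proof_0, proof_1) = share(&proof);

    // Both servers must use identical query randomness (hence the shared
    // verify_key that reset_servers now distributes).
    let query_rand = random_vector(count.query_rand_len()).unwrap();
    let verifier_0 = count.query(&input_0, &proof_0, &query_rand, &[], 2).unwrap();
    let verifier_1 = count.query(&input_1, &proof_1, &query_rand, &[], 2).unwrap();

    // Leader: sum the verifier shares, then decide, as in run_flp_queries.
    let verifier: Vec<Field64> = verifier_0
        .iter()
        .zip(verifier_1.iter())
        .map(|(v0, v1)| *v0 + *v1)
        .collect();
    assert!(count.decide(&verifier).unwrap());
}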
62 changes: 48 additions & 14 deletions src/bin/server.rs
@@ -1,22 +1,25 @@
+use std::{
+    io,
+    sync::{Arc, Mutex},
+    time::Instant,
+};
+
+use futures::{future, prelude::*};
 use mastic::{
     collect, config, prg,
     rpc::{
-        AddKeysRequest, Collector, FinalSharesRequest, ResetRequest, TreeCrawlLastRequest,
-        TreeCrawlRequest, TreeInitRequest, TreePruneRequest,
+        AddFLPsRequest, AddKeysRequest, ApplyFLPResultsRequest, Collector, FinalSharesRequest,
+        ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest, TreeCrawlRequest,
+        TreeInitRequest, TreePruneRequest,
     },
 };
-
-use futures::{future, prelude::*};
 use prio::field::Field64;
-use std::time::Instant;
-use std::{
-    io,
-    sync::{Arc, Mutex},
-};
 use tarpc::{
     context,
     serde_transport::tcp,
-    server::{self, Channel},
+    server::{
+        Channel, {self},
+    },
     tokio_serde::formats::Bincode,
 };
 
@@ -30,9 +33,14 @@ struct CollectorServer {
 
 #[tarpc::server]
 impl Collector for CollectorServer {
-    async fn reset(self, _: context::Context, _req: ResetRequest) -> String {
+    async fn reset(self, _: context::Context, req: ResetRequest) -> String {
         let mut coll = self.arc.lock().unwrap();
-        *coll = collect::KeyCollection::new(self.server_id, &self.seed, self.data_bytes);
+        *coll = collect::KeyCollection::new(
+            self.server_id,
+            &self.seed,
+            self.data_bytes,
+            req.verify_key,
+        );
         "Done".to_string()
     }
 
@@ -47,6 +55,17 @@ impl Collector for CollectorServer {
         "Done".to_string()
     }
 
+    async fn add_all_flp_proof_shares(self, _: context::Context, req: AddFLPsRequest) -> String {
+        let mut coll = self.arc.lock().unwrap();
+        for flp_proof_share in req.flp_proof_shares {
+            coll.add_flp_proof_share(flp_proof_share);
+        }
+        if coll.keys.len() % 10000 == 0 {
+            println!("Number of keys: {:?}", coll.keys.len());
+        }
+        "Done".to_string()
+    }
+
     async fn tree_init(self, _: context::Context, _req: TreeInitRequest) -> String {
         let start = Instant::now();
         let mut coll = self.arc.lock().unwrap();
@@ -60,7 +79,6 @@ impl Collector for CollectorServer {
         _: context::Context,
         req: TreeCrawlRequest,
     ) -> (Vec<Field64>, Vec<Vec<u8>>, Vec<usize>) {
-        // let start = Instant::now();
         let split_by = req.split_by;
         let malicious = req.malicious;
         let is_last = req.is_last;
@@ -69,6 +87,22 @@ impl Collector for CollectorServer {
         coll.tree_crawl(split_by, &malicious, is_last)
     }
 
+    async fn run_flp_queries(
+        self,
+        _: context::Context,
+        _req: RunFlpQueriesRequest,
+    ) -> Vec<Vec<Field64>> {
+        let mut coll = self.arc.lock().unwrap();
+
+        coll.run_flp_queries()
+    }
+
+    async fn apply_flp_results(self, _: context::Context, req: ApplyFLPResultsRequest) -> String {
+        let mut coll = self.arc.lock().unwrap();
+        coll.apply_flp_results(&req.keep);
+        "Done".to_string()
+    }
+
     async fn tree_crawl_last(
         self,
         _: context::Context,
@@ -108,7 +142,7 @@ async fn main() -> io::Result<()> {
 
     let seed = prg::PrgSeed { key: [1u8; 16] };
 
-    let coll = collect::KeyCollection::new(server_id, &seed, cfg.data_bytes * 8);
+    let coll = collect::KeyCollection::new(server_id, &seed, cfg.data_bytes * 8, [0u8; 16]);
     let arc = Arc::new(Mutex::new(coll));
 
     println!("Server {} running at {:?}", server_id, server_addr);
[Diffs for the remaining 8 changed files are not rendered on this page.]
