Receive verifier shares in RPC batches
jimouris committed Nov 1, 2023
1 parent fee03a6 commit eaa6856
Showing 7 changed files with 121 additions and 76 deletions.
README.md (8 changes: 5 additions & 3 deletions)
@@ -60,8 +60,9 @@ system. An example of one such file is in `src/bin/config.json`. The contents of
     "threshold": 0.01,
     "server_0": "0.0.0.0:8000",
     "server_1": "0.0.0.0:8001",
-    "addkey_batch_size": 100,
-    "unique_buckets": 10,
+    "add_key_batch_size": 1000,
+    "flp_batch_size": 100000,
+    "unique_buckets": 1000,
     "zipf_exponent": 1.03
 }
 ```
@@ -72,9 +73,10 @@ The parameters are:
   clients hold.
 * `server_0` and `server_1`: The `IP:port` tuple for each of the two servers. The servers can
   run on different IP addresses, but these IPs must be publicly addressable.
-* `addkey_batch_size`: The number of each type of RPC request to bundle together. The underlying RPC
+* `add_key_batch_size`: The number of each type of RPC request to bundle together. The underlying RPC
   library has an annoying limit on the size of each RPC request, so you cannot set these values too
   large.
+* `flp_batch_size`: Similar to `add_key_batch_size`, but with a larger batch limit (see the sketch below).
 * `unique_buckets` and `zipf_exponent`: Each simulated client samples its private string from a Zipf
   distribution over strings with parameter `zipf_exponent` and support `unique_buckets`.
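
To make the batching concrete: a minimal, self-contained sketch of the windowing pattern that `flp_batch_size` drives in the `src/bin/leader.rs` diff below. `for_each_batch` and `fetch` are illustrative names, not part of the repository:

```rust
// Sketch: split `num_clients` into half-open windows [start, end), each at
// most `batch_size` wide; `fetch` stands in for one paired RPC round trip.
fn for_each_batch(num_clients: usize, batch_size: usize, mut fetch: impl FnMut(usize, usize)) {
    assert!(batch_size > 0);
    let mut start = 0;
    while start < num_clients {
        // The final window is simply shorter; nothing is padded.
        let end = std::cmp::min(num_clients, start + batch_size);
        fetch(start, end);
        start += batch_size;
    }
}

fn main() {
    // 10 clients with a batch size of 4 yield [0, 4), [4, 8), [8, 10).
    for_each_batch(10, 4, |start, end| println!("[{start}, {end})"));
}
```

Each window becomes one `RunFlpQueriesRequest { start, end }` (or `TreeCrawlLastRequest { start, end }`), which keeps every RPC message bounded regardless of the number of clients.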

src/bin/config.json (5 changes: 3 additions & 2 deletions)
@@ -3,7 +3,8 @@
     "threshold": 0.01,
     "server_0": "0.0.0.0:8000",
     "server_1": "0.0.0.0:8001",
-    "addkey_batch_size": 100,
-    "unique_buckets": 10,
+    "add_key_batch_size": 1000,
+    "flp_batch_size": 100000,
+    "unique_buckets": 1000,
     "zipf_exponent": 1.03
 }
src/bin/leader.rs (151 changes: 89 additions & 62 deletions)
@@ -85,18 +85,18 @@ async fn reset_servers(
     let req = ResetRequest {
         verify_key: *verify_key,
     };
-    let response_0 = client_0.reset(long_context(), req.clone());
-    let response_1 = client_1.reset(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.reset(long_context(), req.clone());
+    let resp_1 = client_1.reset(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
 
 async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> {
     let req = TreeInitRequest {};
-    let response_0 = client_0.tree_init(long_context(), req.clone());
-    let response_1 = client_1.tree_init(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_init(long_context(), req.clone());
+    let resp_1 = client_1.tree_init(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
@@ -141,55 +141,70 @@ async fn add_keys(
         flp_proof_shares_1.push(proofs_1[idx_3 % cfg.unique_buckets].clone());
     }
 
-    let response_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
-    let response_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
+    let resp_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
+    try_join!(resp_0, resp_1).unwrap();
 
-    let response_0 = client_0.add_all_flp_proof_shares(
+    let resp_0 = client_0.add_all_flp_proof_shares(
         long_context(),
         AddFLPsRequest {
             flp_proof_shares: flp_proof_shares_0,
         },
     );
-    let response_1 = client_1.add_all_flp_proof_shares(
+    let resp_1 = client_1.add_all_flp_proof_shares(
         long_context(),
         AddFLPsRequest {
             flp_proof_shares: flp_proof_shares_1,
         },
     );
-    try_join!(response_0, response_1).unwrap();
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
 
-async fn run_flp_queries(client_0: &Client, client_1: &Client) -> io::Result<()> {
-    let req = RunFlpQueriesRequest {};
-    let response_0 = client_0.run_flp_queries(long_context(), req.clone());
-    let response_1 = client_1.run_flp_queries(long_context(), req);
-    let (flp_verifier_shares_0, flp_verifier_shares_1) = try_join!(response_0, response_1).unwrap();
-
-    assert_eq!(flp_verifier_shares_0.len(), flp_verifier_shares_1.len());
-
+async fn run_flp_queries(
+    cfg: &config::Config,
+    client_0: &Client,
+    client_1: &Client,
+    num_clients: usize,
+) -> io::Result<()> {
+    // Receive FLP query responses in chunks of cfg.flp_batch_size to avoid having huge RPC messages.
     let count = Count::new();
-    let keep = flp_verifier_shares_0
-        .par_iter()
-        .zip(flp_verifier_shares_1.par_iter())
-        .map(|(flp_verifier_share_0, flp_verifier_share_1)| {
-            let flp_verifier = flp_verifier_share_0
-                .par_iter()
-                .zip(flp_verifier_share_1.par_iter())
-                .map(|(&v1, &v2)| v1 + v2)
-                .collect::<Vec<_>>();
-
-            count.decide(&flp_verifier).unwrap()
-        })
-        .collect::<Vec<_>>();
+    let mut keep = vec![];
+    let mut start = 0;
+    while start < num_clients {
+        let end = std::cmp::min(num_clients, start + cfg.flp_batch_size);
+
+        let req = RunFlpQueriesRequest { start, end };
+        let resp_0 = client_0.run_flp_queries(long_context(), req.clone());
+        let resp_1 = client_1.run_flp_queries(long_context(), req);
+        let (flp_verifier_shares_0, flp_verifier_shares_1) = try_join!(resp_0, resp_1).unwrap();
+        debug_assert_eq!(flp_verifier_shares_0.len(), flp_verifier_shares_1.len());
+
+        keep.extend(
+            flp_verifier_shares_0
+                .par_iter()
+                .zip(flp_verifier_shares_1.par_iter())
+                .map(|(flp_verifier_share_0, flp_verifier_share_1)| {
+                    let flp_verifier = flp_verifier_share_0
+                        .par_iter()
+                        .zip(flp_verifier_share_1.par_iter())
+                        .map(|(&v1, &v2)| v1 + v2)
+                        .collect::<Vec<_>>();
 
+                    count.decide(&flp_verifier).unwrap()
+                })
+                .collect::<Vec<_>>(),
+        );
+
+        start += cfg.flp_batch_size;
+    }
 
     // Tree prune
     let req = ApplyFLPResultsRequest { keep };
-    let response_0 = client_0.apply_flp_results(long_context(), req.clone());
-    let response_1 = client_1.apply_flp_results(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.apply_flp_results(long_context(), req.clone());
+    let resp_1 = client_1.apply_flp_results(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
Expand All @@ -213,10 +228,10 @@ async fn run_level(
is_last,
};

let response_0 = client_0.tree_crawl(long_context(), req.clone());
let response_1 = client_1.tree_crawl(long_context(), req);
let resp_0 = client_0.tree_crawl(long_context(), req.clone());
let resp_1 = client_1.tree_crawl(long_context(), req);
let ((cnt_values_0, mt_root_0, indices_0), (cnt_values_1, mt_root_1, indices_1)) =
try_join!(response_0, response_1).unwrap();
try_join!(resp_0, resp_1).unwrap();

assert_eq!(cnt_values_0.len(), cnt_values_1.len());
keep =
@@ -254,9 +269,9 @@ async fn run_level(
 
     // Tree prune
     let req = TreePruneRequest { keep };
-    let response_0 = client_0.tree_prune(long_context(), req.clone());
-    let response_1 = client_1.tree_prune(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_prune(long_context(), req.clone());
+    let resp_1 = client_1.tree_prune(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
@@ -269,34 +284,46 @@ async fn run_level_last(
 ) -> io::Result<()> {
     let threshold = core::cmp::max(1, (cfg.threshold * (num_clients as f64)) as u64);
 
-    let req = TreeCrawlLastRequest {};
-    let response_0 = client_0.tree_crawl_last(long_context(), req.clone());
-    let response_1 = client_1.tree_crawl_last(long_context(), req);
-    let ((cnt_values_0, hashes_0), (cnt_values_1, hashes_1)) =
-        try_join!(response_0, response_1).unwrap();
-
-    assert_eq!(cnt_values_0.len(), cnt_values_1.len());
-    assert_eq!(hashes_0.len(), hashes_1.len());
-
-    let verified = hashes_0
-        .par_iter()
-        .zip(hashes_1.par_iter())
-        .all(|(&h0, &h1)| h0 == h1);
-    assert!(verified);
-
-    let keep =
-        collect::KeyCollection::<Field64>::keep_values(threshold, &cnt_values_0, &cnt_values_1);
+    // Receive counters in chunks to avoid having huge RPC messages.
+    let mut keep = vec![];
+    let mut start = 0;
+    while start < num_clients {
+        let end = std::cmp::min(num_clients, start + cfg.flp_batch_size);
+
+        let req = TreeCrawlLastRequest { start, end };
+        let resp_0 = client_0.tree_crawl_last(long_context(), req.clone());
+        let resp_1 = client_1.tree_crawl_last(long_context(), req);
+        let ((cnt_values_0, hashes_0), (cnt_values_1, hashes_1)) =
+            try_join!(resp_0, resp_1).unwrap();
+
+        assert_eq!(cnt_values_0.len(), cnt_values_1.len());
+        assert_eq!(hashes_0.len(), hashes_1.len());
+
+        let verified = hashes_0
+            .par_iter()
+            .zip(hashes_1.par_iter())
+            .all(|(&h0, &h1)| h0 == h1);
+        assert!(verified);
+
+        keep.extend(collect::KeyCollection::<Field64>::keep_values(
+            threshold,
+            &cnt_values_0,
+            &cnt_values_1,
+        ));
+
+        start += cfg.flp_batch_size;
+    }
 
     // Tree prune
     let req = TreePruneRequest { keep };
-    let response_0 = client_0.tree_prune(long_context(), req.clone());
-    let response_1 = client_1.tree_prune(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_prune(long_context(), req.clone());
+    let resp_1 = client_1.tree_prune(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     let req = FinalSharesRequest {};
-    let response_0 = client_0.final_shares(long_context(), req.clone());
-    let response_1 = client_1.final_shares(long_context(), req);
-    let (shares_0, shares_1) = try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.final_shares(long_context(), req.clone());
+    let resp_1 = client_1.final_shares(long_context(), req);
+    let (shares_0, shares_1) = try_join!(resp_0, resp_1).unwrap();
     for res in &collect::KeyCollection::<Field64>::final_values(&shares_0, &shares_1) {
         let bits = mastic::bits_to_bitstring(&res.path);
         if res.value > Field64::from(0) {
@@ -348,7 +375,7 @@ async fn main() -> io::Result<()> {
     let mut responses = vec![];
 
     for _ in 0..reqs_in_flight {
-        let this_batch = std::cmp::min(left_to_go, cfg.addkey_batch_size);
+        let this_batch = std::cmp::min(left_to_go, cfg.add_key_batch_size);
         left_to_go -= this_batch;
 
         if this_batch > 0 {
@@ -371,7 +398,7 @@ async fn main() -> io::Result<()> {
     for level in 0..bit_len - 1 {
         let start_level = Instant::now();
         if level == 0 {
-            run_flp_queries(&client_0, &client_1).await?;
+            run_flp_queries(&cfg, &client_0, &client_1, num_clients).await?;
         }
         run_level(&cfg, &client_0, &client_1, num_clients).await?;
         println!(
src/bin/server.rs (7 changes: 5 additions & 2 deletions)
@@ -90,11 +90,14 @@ impl Collector for CollectorServer
     async fn run_flp_queries(
         self,
         _: context::Context,
-        _req: RunFlpQueriesRequest,
+        req: RunFlpQueriesRequest,
     ) -> Vec<Vec<Field64>> {
         let mut coll = self.arc.lock().unwrap();
+        let start = req.start;
+        let end = req.end;
+        assert!(start < end);
 
-        coll.run_flp_queries()
+        coll.run_flp_queries(start, end)
     }
 
     async fn apply_flp_results(self, _: context::Context, req: ApplyFLPResultsRequest) -> String {
src/collect.rs (3 changes: 2 additions & 1 deletion)
@@ -137,7 +137,7 @@ where
         child
     }
 
-    pub fn run_flp_queries(&mut self) -> Vec<Vec<T>> {
+    pub fn run_flp_queries(&mut self, start: usize, end: usize) -> Vec<Vec<T>> {
         let level = self.frontier[0].path.len();
         assert_eq!(level, 0);
 
@@ -156,6 +156,7 @@
             .key_values
             .par_iter()
             .enumerate()
+            .filter(|(client_index, _)| *client_index >= start && *client_index < end)
             .map(|(client_index, _)| {
                 let y_p0 = node_left.key_values[client_index];
                 let y_p1 = node_right.key_values[client_index];
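
The new `filter` keeps the parallel iterator intact but restricts work to the requested window. In isolation the pattern looks like this (a self-contained sketch using rayon, with toy data rather than the repo's key state):

```rust
use rayon::prelude::*;

fn main() {
    let key_values = vec![10u64, 20, 30, 40, 50];
    let (start, end) = (1, 4);
    // Enumerate all clients, keep only indices in [start, end), then map.
    let windowed: Vec<u64> = key_values
        .par_iter()
        .enumerate()
        .filter(|(client_index, _)| *client_index >= start && *client_index < end)
        .map(|(_, v)| v * 2)
        .collect();
    assert_eq!(windowed, vec![40, 60, 80]);
}
```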
src/config.rs (13 changes: 9 additions & 4 deletions)
@@ -5,7 +5,8 @@ use serde_json::Value;
 
 pub struct Config {
     pub data_bytes: usize,
-    pub addkey_batch_size: usize,
+    pub add_key_batch_size: usize,
+    pub flp_batch_size: usize,
     pub unique_buckets: usize,
     pub threshold: f64,
     pub zipf_exponent: f64,
@@ -22,9 +23,12 @@ pub fn get_config(filename: &str) -> Config {
     let v: Value = serde_json::from_str(json_data).expect("Cannot parse JSON config");
 
     let data_bytes: usize = v["data_bytes"].as_u64().expect("Can't parse data_bytes") as usize;
-    let addkey_batch_size: usize = v["addkey_batch_size"]
+    let add_key_batch_size: usize = v["add_key_batch_size"]
         .as_u64()
-        .expect("Can't parse addkey_batch_size") as usize;
+        .expect("Can't parse add_key_batch_size") as usize;
+    let flp_batch_size: usize = v["flp_batch_size"]
+        .as_u64()
+        .expect("Can't parse flp_batch_size") as usize;
     let unique_buckets: usize = v["unique_buckets"]
         .as_u64()
         .expect("Can't parse unique_buckets") as usize;
@@ -37,7 +41,8 @@
 
     Config {
         data_bytes,
-        addkey_batch_size,
+        add_key_batch_size,
+        flp_batch_size,
         unique_buckets,
         threshold,
         zipf_exponent,
src/rpc.rs (10 changes: 8 additions & 2 deletions)
@@ -34,10 +34,16 @@ pub struct TreeCrawlRequest {
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct TreeCrawlLastRequest {}
+pub struct TreeCrawlLastRequest {
+    pub start: usize,
+    pub end: usize,
+}
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct RunFlpQueriesRequest {}
+pub struct RunFlpQueriesRequest {
+    pub start: usize,
+    pub end: usize,
+}
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct TreePruneRequest {
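
With the `start`/`end` fields above, each server answers for only a window of clients per request. A minimal, self-contained sketch of that server-side slicing, where `serve_window` is a hypothetical helper (the real code filters by client index inside `src/collect.rs`):

```rust
// Sketch: return per-client results only for the half-open window [start, end).
fn serve_window<T: Clone>(per_client: &[T], start: usize, end: usize) -> Vec<T> {
    // Mirrors the `assert!(start < end)` guard added in src/bin/server.rs.
    assert!(start < end, "window must be non-empty");
    assert!(start < per_client.len(), "window starts past the data");
    // Clamp defensively; the leader already bounds `end` by the client count.
    let end = std::cmp::min(end, per_client.len());
    per_client[start..end].to_vec()
}

fn main() {
    let counters = vec![3u64, 1, 4, 1, 5, 9, 2, 6];
    // The last batch may extend past the data; it is clamped, not padded.
    assert_eq!(serve_window(&counters, 4, 100), vec![5, 9, 2, 6]);
}
```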
