From c05c73289a9a14c9f5522975f98ce6822bdc2634 Mon Sep 17 00:00:00 2001
From: Dimitris Mouris
Date: Wed, 1 Nov 2023 17:04:32 -0400
Subject: [PATCH] Receive verifier shares in RPC batches

---
 README.md           |   6 ++-
 src/bin/config.json |   3 +-
 src/bin/leader.rs   | 118 +++++++++++++++++++++++++-------------------
 src/bin/server.rs   |   7 ++-
 src/collect.rs      |   3 +-
 src/config.rs       |  11 +++--
 src/rpc.rs          |   5 +-
 7 files changed, 89 insertions(+), 64 deletions(-)

diff --git a/README.md b/README.md
index b2778f7..63232a7 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,8 @@ system. An example of one such file is in `src/bin/config.json`. The contents of
   "threshold": 0.01,
   "server_0": "0.0.0.0:8000",
   "server_1": "0.0.0.0:8001",
-  "addkey_batch_size": 100,
+  "add_key_batch_size": 1000,
+  "flp_batch_size": 10000,
   "unique_buckets": 10,
   "zipf_exponent": 1.03
 }
@@ -72,9 +73,10 @@ The parameters are:
   clients hold.
 * `server_0` and `server_1`: The `IP:port` tuple for the two servers. The servers can run on
   different IP addresses, but these IPs must be publicly addressable.
-* `addkey_batch_size`: The number of each type of RPC request to bundle together. The underlying RPC
+* `add_key_batch_size`: The number of each type of RPC request to bundle together. The underlying RPC
   library has an annoying limit on the size of each RPC request, so you cannot set these values too
   large.
+* `flp_batch_size`: Similar to `add_key_batch_size` but for the FLP-related RPCs, which allow a larger batch size.
 * `unique_buckets` and `zipf_exponent`: Each simulated client samples its private string from a Zipf
   distribution over strings with parameter `zipf_exponent` and support `unique_buckets`.
 
diff --git a/src/bin/config.json b/src/bin/config.json
index 894a8d0..caa0b7c 100644
--- a/src/bin/config.json
+++ b/src/bin/config.json
@@ -3,7 +3,8 @@
   "threshold": 0.01,
   "server_0": "0.0.0.0:8000",
   "server_1": "0.0.0.0:8001",
-  "addkey_batch_size": 100,
+  "add_key_batch_size": 1000,
+  "flp_batch_size": 10000,
   "unique_buckets": 10,
   "zipf_exponent": 1.03
 }
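The batch-size knobs above exist because each RPC message must stay under the transport's size cap: the leader splits the client submissions into fixed-size windows and issues one request per window. Below is a minimal, dependency-free sketch of that windowing arithmetic; the helper name `batch_ranges` is illustrative and not part of this patch.

```rust
// Sketch (not patch code): split `num_clients` items into half-open
// [start, end) windows of at most `batch_size` items each.
fn batch_ranges(num_clients: usize, batch_size: usize) -> Vec<(usize, usize)> {
    let mut ranges = vec![];
    let mut start = 0;
    while start < num_clients {
        let end = std::cmp::min(num_clients, start + batch_size);
        ranges.push((start, end));
        start += batch_size;
    }
    ranges
}

fn main() {
    // 2500 clients with a batch size of 1000 yield three RPC round trips:
    // [0, 1000), [1000, 2000), [2000, 2500).
    assert_eq!(
        batch_ranges(2500, 1000),
        vec![(0, 1000), (1000, 2000), (2000, 2500)]
    );
}
```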
diff --git a/src/bin/leader.rs b/src/bin/leader.rs
index 1613eb9..23a29a4 100644
--- a/src/bin/leader.rs
+++ b/src/bin/leader.rs
@@ -85,18 +85,18 @@ async fn reset_servers(
     let req = ResetRequest {
         verify_key: *verify_key,
     };
-    let response_0 = client_0.reset(long_context(), req.clone());
-    let response_1 = client_1.reset(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.reset(long_context(), req.clone());
+    let resp_1 = client_1.reset(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
 
 async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> {
     let req = TreeInitRequest {};
-    let response_0 = client_0.tree_init(long_context(), req.clone());
-    let response_1 = client_1.tree_init(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_init(long_context(), req.clone());
+    let resp_1 = client_1.tree_init(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
@@ -141,55 +141,70 @@ async fn add_keys(
         flp_proof_shares_1.push(proofs_1[idx_3 % cfg.unique_buckets].clone());
     }
 
-    let response_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
-    let response_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
+    let resp_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
+    try_join!(resp_0, resp_1).unwrap();
 
-    let response_0 = client_0.add_all_flp_proof_shares(
+    let resp_0 = client_0.add_all_flp_proof_shares(
         long_context(),
         AddFLPsRequest {
             flp_proof_shares: flp_proof_shares_0,
         },
     );
-    let response_1 = client_1.add_all_flp_proof_shares(
+    let resp_1 = client_1.add_all_flp_proof_shares(
         long_context(),
         AddFLPsRequest {
             flp_proof_shares: flp_proof_shares_1,
         },
     );
-    try_join!(response_0, response_1).unwrap();
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
 
-async fn run_flp_queries(client_0: &Client, client_1: &Client) -> io::Result<()> {
-    let req = RunFlpQueriesRequest {};
-    let response_0 = client_0.run_flp_queries(long_context(), req.clone());
-    let response_1 = client_1.run_flp_queries(long_context(), req);
-    let (flp_verifier_shares_0, flp_verifier_shares_1) = try_join!(response_0, response_1).unwrap();
-
-    assert_eq!(flp_verifier_shares_0.len(), flp_verifier_shares_1.len());
-
+async fn run_flp_queries(
+    cfg: &config::Config,
+    client_0: &Client,
+    client_1: &Client,
+    num_clients: usize,
+) -> io::Result<()> {
+    // Receive FLP query responses in chunks of cfg.add_key_batch_size to avoid having huge RPC messages.
     let count = Count::new();
-    let keep = flp_verifier_shares_0
-        .par_iter()
-        .zip(flp_verifier_shares_1.par_iter())
-        .map(|(flp_verifier_share_0, flp_verifier_share_1)| {
-            let flp_verifier = flp_verifier_share_0
+    let mut keep = vec![];
+    let mut start = 0;
+    while start < num_clients {
+        let end = std::cmp::min(num_clients, start + cfg.add_key_batch_size);
+
+        let req = RunFlpQueriesRequest { start, end };
+        let resp_0 = client_0.run_flp_queries(long_context(), req.clone());
+        let resp_1 = client_1.run_flp_queries(long_context(), req);
+        let (flp_verifier_shares_0, flp_verifier_shares_1) = try_join!(resp_0, resp_1).unwrap();
+        debug_assert_eq!(flp_verifier_shares_0.len(), flp_verifier_shares_1.len());
+
+        keep.extend(
+            flp_verifier_shares_0
                 .par_iter()
-                .zip(flp_verifier_share_1.par_iter())
-                .map(|(&v1, &v2)| v1 + v2)
-                .collect::<Vec<_>>();
+                .zip(flp_verifier_shares_1.par_iter())
+                .map(|(flp_verifier_share_0, flp_verifier_share_1)| {
+                    let flp_verifier = flp_verifier_share_0
+                        .par_iter()
+                        .zip(flp_verifier_share_1.par_iter())
+                        .map(|(&v1, &v2)| v1 + v2)
+                        .collect::<Vec<_>>();
+
+                    count.decide(&flp_verifier).unwrap()
+                })
+                .collect::<Vec<_>>(),
+        );
 
-            count.decide(&flp_verifier).unwrap()
-        })
-        .collect::<Vec<_>>();
+        start += cfg.add_key_batch_size;
+    }
 
     // Tree prune
     let req = ApplyFLPResultsRequest { keep };
-    let response_0 = client_0.apply_flp_results(long_context(), req.clone());
-    let response_1 = client_1.apply_flp_results(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.apply_flp_results(long_context(), req.clone());
+    let resp_1 = client_1.apply_flp_results(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
@@ -213,10 +228,10 @@ async fn run_level(
         is_last,
     };
 
-    let response_0 = client_0.tree_crawl(long_context(), req.clone());
-    let response_1 = client_1.tree_crawl(long_context(), req);
+    let resp_0 = client_0.tree_crawl(long_context(), req.clone());
+    let resp_1 = client_1.tree_crawl(long_context(), req);
     let ((cnt_values_0, mt_root_0, indices_0), (cnt_values_1, mt_root_1, indices_1)) =
-        try_join!(response_0, response_1).unwrap();
+        try_join!(resp_0, resp_1).unwrap();
 
     assert_eq!(cnt_values_0.len(), cnt_values_1.len());
     keep =
@@ -254,9 +269,9 @@ async fn run_level(
 
     // Tree prune
     let req = TreePruneRequest { keep };
-    let response_0 = client_0.tree_prune(long_context(), req.clone());
-    let response_1 = client_1.tree_prune(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_prune(long_context(), req.clone());
+    let resp_1 = client_1.tree_prune(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
@@ -270,10 +285,9 @@ async fn run_level_last(
     let threshold = core::cmp::max(1, (cfg.threshold * (num_clients as f64)) as u64);
 
     let req = TreeCrawlLastRequest {};
-    let response_0 = client_0.tree_crawl_last(long_context(), req.clone());
-    let response_1 = client_1.tree_crawl_last(long_context(), req);
-    let ((cnt_values_0, hashes_0), (cnt_values_1, hashes_1)) =
-        try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_crawl_last(long_context(), req.clone());
+    let resp_1 = client_1.tree_crawl_last(long_context(), req);
+    let ((cnt_values_0, hashes_0), (cnt_values_1, hashes_1)) = try_join!(resp_0, resp_1).unwrap();
 
     assert_eq!(cnt_values_0.len(), cnt_values_1.len());
     assert_eq!(hashes_0.len(), hashes_1.len());
@@ -289,14 +303,14 @@ async fn run_level_last(
 
     // Tree prune
     let req = TreePruneRequest { keep };
-    let response_0 = client_0.tree_prune(long_context(), req.clone());
-    let response_1 = client_1.tree_prune(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_prune(long_context(), req.clone());
+    let resp_1 = client_1.tree_prune(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     let req = FinalSharesRequest {};
-    let response_0 = client_0.final_shares(long_context(), req.clone());
-    let response_1 = client_1.final_shares(long_context(), req);
-    let (shares_0, shares_1) = try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.final_shares(long_context(), req.clone());
+    let resp_1 = client_1.final_shares(long_context(), req);
+    let (shares_0, shares_1) = try_join!(resp_0, resp_1).unwrap();
     for res in &collect::KeyCollection::<Field64>::final_values(&shares_0, &shares_1) {
         let bits = mastic::bits_to_bitstring(&res.path);
         if res.value > Field64::from(0) {
@@ -348,7 +362,7 @@ async fn main() -> io::Result<()> {
 
             let mut responses = vec![];
             for _ in 0..reqs_in_flight {
-                let this_batch = std::cmp::min(left_to_go, cfg.addkey_batch_size);
+                let this_batch = std::cmp::min(left_to_go, cfg.add_key_batch_size);
                 left_to_go -= this_batch;
 
                 if this_batch > 0 {
@@ -371,7 +385,7 @@ async fn main() -> io::Result<()> {
     for level in 0..bit_len - 1 {
         let start_level = Instant::now();
         if level == 0 {
-            run_flp_queries(&client_0, &client_1).await?;
+            run_flp_queries(&cfg, &client_0, &client_1, num_clients).await?;
         }
         run_level(&cfg, &client_0, &client_1, num_clients).await?;
         println!(
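The heart of the change is that `run_flp_queries` now pulls verifier shares in `[start, end)` windows and, within each window, recombines the two servers' shares element-wise before handing the result to the FLP decision step. A toy sketch of that recombination, using plain `u64` with wrapping addition as a stand-in for `Field64` arithmetic; `recombine` is an illustrative name, not patch code.

```rust
// Sketch (not patch code): recombining additively secret-shared verifier
// vectors. With shares v_0 and v_1 of equal length, the plaintext verifier
// is the element-wise sum v_0[i] + v_1[i]; neither share alone reveals it.
fn recombine(share_0: &[u64], share_1: &[u64]) -> Vec<u64> {
    assert_eq!(share_0.len(), share_1.len());
    share_0
        .iter()
        .zip(share_1.iter())
        .map(|(&v0, &v1)| v0.wrapping_add(v1))
        .collect()
}

fn main() {
    // One share per server; the leader sums them pairwise.
    let share_0 = vec![17, 2, 940];
    let share_1 = vec![5, 11, 60];
    assert_eq!(recombine(&share_0, &share_1), vec![22, 13, 1000]);
}
```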
diff --git a/src/bin/server.rs b/src/bin/server.rs
index e90d50b..12f6668 100644
--- a/src/bin/server.rs
+++ b/src/bin/server.rs
@@ -90,11 +90,14 @@ impl Collector for CollectorServer {
     async fn run_flp_queries(
         self,
         _: context::Context,
-        _req: RunFlpQueriesRequest,
+        req: RunFlpQueriesRequest,
     ) -> Vec<Vec<Field64>> {
         let mut coll = self.arc.lock().unwrap();
+        let start = req.start;
+        let end = req.end;
+        assert!(start < end);
 
-        coll.run_flp_queries()
+        coll.run_flp_queries(start, end)
     }
 
     async fn apply_flp_results(self, _: context::Context, req: ApplyFLPResultsRequest) -> String {
diff --git a/src/collect.rs b/src/collect.rs
index 762da50..b435fed 100644
--- a/src/collect.rs
+++ b/src/collect.rs
@@ -137,7 +137,7 @@ where
         child
     }
 
-    pub fn run_flp_queries(&mut self) -> Vec<Vec<Field64>> {
+    pub fn run_flp_queries(&mut self, start: usize, end: usize) -> Vec<Vec<Field64>> {
         let level = self.frontier[0].path.len();
         assert_eq!(level, 0);
 
@@ -156,6 +156,7 @@ where
             .key_values
             .par_iter()
             .enumerate()
+            .filter(|(client_index, _)| *client_index >= start && *client_index < end)
             .map(|(client_index, _)| {
                 let y_p0 = node_left.key_values[client_index];
                 let y_p1 = node_right.key_values[client_index];
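In `collect.rs`, the per-client pass stays a single `par_iter()` over all key values and simply drops indices outside the requested window with `filter` on `enumerate`, which leaves the rest of the rayon pipeline unchanged. The same idiom on a plain sequential iterator, sketched with an illustrative `window_sum` helper that is not part of this patch:

```rust
// Sketch (not patch code): restrict an indexed iterator to the half-open
// window [start, end) while keeping the original indices available.
fn window_sum(values: &[u64], start: usize, end: usize) -> u64 {
    values
        .iter()
        .enumerate()
        .filter(|(i, _)| *i >= start && *i < end)
        .map(|(_, v)| *v)
        .sum()
}

fn main() {
    let values = [1, 2, 3, 4, 5];
    // Only indices 1, 2, and 3 fall in [1, 4).
    assert_eq!(window_sum(&values, 1, 4), 2 + 3 + 4);
}
```

A `skip`/`take` pair over the enumerated iterator would express the same window; `filter` keeps the adapter chain a one-line change to the existing code.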
diff --git a/src/config.rs b/src/config.rs
index 9f825c3..49c074b 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -5,7 +5,8 @@ use serde_json::Value;
 
 pub struct Config {
     pub data_bytes: usize,
-    pub addkey_batch_size: usize,
+    pub add_key_batch_size: usize,
+    pub flp_batch_size: usize,
     pub unique_buckets: usize,
     pub threshold: f64,
     pub zipf_exponent: f64,
@@ -22,9 +23,8 @@ pub fn get_config(filename: &str) -> Config {
     let v: Value = serde_json::from_str(json_data).expect("Cannot parse JSON config");
 
     let data_bytes: usize = v["data_bytes"].as_u64().expect("Can't parse data_bytes") as usize;
-    let addkey_batch_size: usize = v["addkey_batch_size"]
-        .as_u64()
-        .expect("Can't parse addkey_batch_size") as usize;
+    let add_key_batch_size: usize = v["add_key_batch_size"].as_u64().expect("Can't parse add_key_batch_size") as usize;
+    let flp_batch_size: usize = v["flp_batch_size"].as_u64().expect("Can't parse flp_batch_size") as usize;
     let unique_buckets: usize = v["unique_buckets"]
         .as_u64()
         .expect("Can't parse unique_buckets") as usize;
@@ -37,7 +37,8 @@ pub fn get_config(filename: &str) -> Config {
 
     Config {
         data_bytes,
-        addkey_batch_size,
+        add_key_batch_size,
+        flp_batch_size,
         unique_buckets,
         threshold,
         zipf_exponent,
diff --git a/src/rpc.rs b/src/rpc.rs
index 698d93c..83b55c1 100644
--- a/src/rpc.rs
+++ b/src/rpc.rs
@@ -37,7 +37,10 @@ pub struct TreeCrawlRequest {
 pub struct TreeCrawlLastRequest {}
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct RunFlpQueriesRequest {}
+pub struct RunFlpQueriesRequest {
+    pub start: usize,
+    pub end: usize,
+}
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct TreePruneRequest {
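`RunFlpQueriesRequest` is what travels on the wire, so its encoded size is now small and constant regardless of how many clients a batch covers. A quick serde round-trip of the new request type; `bincode` is assumed here purely for illustration, as the project's actual RPC stack may use a different codec.

```rust
// Sketch (not patch code): serde round-trip of the batched request type.
// Assumes serde with the "derive" feature and bincode 1.x as dev-dependencies.
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct RunFlpQueriesRequest {
    pub start: usize,
    pub end: usize,
}

fn main() {
    // The request names a window of clients, not the clients' data itself.
    let req = RunFlpQueriesRequest { start: 0, end: 1000 };
    let bytes = bincode::serialize(&req).unwrap();
    let back: RunFlpQueriesRequest = bincode::deserialize(&bytes).unwrap();
    assert_eq!(req, back);
}
```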