Receive verifier shares in RPC batches
jimouris committed Nov 1, 2023
1 parent fee03a6 commit 86edc8b
Showing 7 changed files with 83 additions and 64 deletions.
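In short: the leader previously fetched the FLP verifier shares for all clients in a single run_flp_queries RPC; this commit makes it request them in half-open `[start, end)` windows of `cfg.batch_size` clients, so no single RPC message grows with the total number of clients. A minimal sketch of the windowing pattern, condensed from the src/bin/leader.rs diff below (the function wrapper and the `fetch` closure are ours, for illustration only):

    // Condensed from the new run_flp_queries() in src/bin/leader.rs below.
    // Windows are half-open, so they tile [0, num_clients) without overlap.
    fn for_each_window(num_clients: usize, batch_size: usize, mut fetch: impl FnMut(usize, usize)) {
        let mut start = 0;
        while start < num_clients {
            let end = std::cmp::min(num_clients, start + batch_size);
            fetch(start, end); // one RunFlpQueriesRequest { start, end } per server
            start += batch_size;
        }
    }

The rename of `addkey_batch_size` to `batch_size` reflects that the same knob now bounds both the add_keys batches and the verifier-share batches.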
4 changes: 2 additions & 2 deletions README.md
@@ -60,7 +60,7 @@ system. An example of one such file is in `src/bin/config.json`. The contents of
   "threshold": 0.01,
   "server_0": "0.0.0.0:8000",
   "server_1": "0.0.0.0:8001",
-  "addkey_batch_size": 100,
+  "batch_size": 1000,
   "unique_buckets": 10,
   "zipf_exponent": 1.03
 }
@@ -72,7 +72,7 @@ The parameters are:
   clients hold.
 * `server_0` and `server_1`: The `IP:port` tuple for the two servers. The servers can
   run on different IP addresses, but these IPs must be publicly addressable.
-* `addkey_batch_size`: The number of each type of RPC request to bundle together. The underlying RPC
+* `batch_size`: The number of each type of RPC request to bundle together. The underlying RPC
   library has an annoying limit on the size of each RPC request, so you cannot set these values too
   large.
 * `unique_buckets` and `zipf_exponent`: Each simulated client samples its private string from a Zipf
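To make the `batch_size` trade-off concrete (illustrative numbers, not from the repository):

    // Illustrative arithmetic only: ceil-divide the clients into windows.
    fn main() {
        let num_clients = 10_000usize;
        let batch_size = 1_000usize; // the value suggested in config.json above
        let num_rpcs_per_server = (num_clients + batch_size - 1) / batch_size;
        assert_eq!(num_rpcs_per_server, 10); // 10 bounded messages instead of 1 huge one
    }

A larger `batch_size` means fewer round trips but bigger messages; the RPC library's message-size limit caps how far you can push it.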
2 changes: 1 addition & 1 deletion src/bin/config.json
@@ -3,7 +3,7 @@
   "threshold": 0.01,
   "server_0": "0.0.0.0:8000",
   "server_1": "0.0.0.0:8001",
-  "addkey_batch_size": 100,
+  "batch_size": 1000,
   "unique_buckets": 10,
   "zipf_exponent": 1.03
 }
118 changes: 66 additions & 52 deletions src/bin/leader.rs
@@ -85,18 +85,18 @@ async fn reset_servers(
     let req = ResetRequest {
         verify_key: *verify_key,
     };
-    let response_0 = client_0.reset(long_context(), req.clone());
-    let response_1 = client_1.reset(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.reset(long_context(), req.clone());
+    let resp_1 = client_1.reset(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
 
 async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> {
     let req = TreeInitRequest {};
-    let response_0 = client_0.tree_init(long_context(), req.clone());
-    let response_1 = client_1.tree_init(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_init(long_context(), req.clone());
+    let resp_1 = client_1.tree_init(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
@@ -141,55 +141,70 @@ async fn add_keys(
         flp_proof_shares_1.push(proofs_1[idx_3 % cfg.unique_buckets].clone());
     }
 
-    let response_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
-    let response_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
+    let resp_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
+    try_join!(resp_0, resp_1).unwrap();
 
-    let response_0 = client_0.add_all_flp_proof_shares(
+    let resp_0 = client_0.add_all_flp_proof_shares(
         long_context(),
         AddFLPsRequest {
            flp_proof_shares: flp_proof_shares_0,
        },
    );
-    let response_1 = client_1.add_all_flp_proof_shares(
+    let resp_1 = client_1.add_all_flp_proof_shares(
        long_context(),
        AddFLPsRequest {
            flp_proof_shares: flp_proof_shares_1,
        },
    );
-    try_join!(response_0, response_1).unwrap();
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
 
-async fn run_flp_queries(client_0: &Client, client_1: &Client) -> io::Result<()> {
-    let req = RunFlpQueriesRequest {};
-    let response_0 = client_0.run_flp_queries(long_context(), req.clone());
-    let response_1 = client_1.run_flp_queries(long_context(), req);
-    let (flp_verifier_shares_0, flp_verifier_shares_1) = try_join!(response_0, response_1).unwrap();
-
-    assert_eq!(flp_verifier_shares_0.len(), flp_verifier_shares_1.len());
-
+async fn run_flp_queries(
+    cfg: &config::Config,
+    client_0: &Client,
+    client_1: &Client,
+    num_clients: usize,
+) -> io::Result<()> {
+    // Receive FLP query responses in chunks of cfg.batch_size to avoid having huge RPC messages.
     let count = Count::new();
-    let keep = flp_verifier_shares_0
-        .par_iter()
-        .zip(flp_verifier_shares_1.par_iter())
-        .map(|(flp_verifier_share_0, flp_verifier_share_1)| {
-            let flp_verifier = flp_verifier_share_0
+    let mut keep = vec![];
+    let mut start = 0;
+    while start < num_clients {
+        let end = std::cmp::min(num_clients, start + cfg.batch_size);
+
+        let req = RunFlpQueriesRequest { start, end };
+        let resp_0 = client_0.run_flp_queries(long_context(), req.clone());
+        let resp_1 = client_1.run_flp_queries(long_context(), req);
+        let (flp_verifier_shares_0, flp_verifier_shares_1) = try_join!(resp_0, resp_1).unwrap();
+        debug_assert_eq!(flp_verifier_shares_0.len(), flp_verifier_shares_1.len());
+
+        keep.extend(
+            flp_verifier_shares_0
                 .par_iter()
-                .zip(flp_verifier_share_1.par_iter())
-                .map(|(&v1, &v2)| v1 + v2)
-                .collect::<Vec<_>>();
+                .zip(flp_verifier_shares_1.par_iter())
+                .map(|(flp_verifier_share_0, flp_verifier_share_1)| {
+                    let flp_verifier = flp_verifier_share_0
+                        .par_iter()
+                        .zip(flp_verifier_share_1.par_iter())
+                        .map(|(&v1, &v2)| v1 + v2)
+                        .collect::<Vec<_>>();
+
+                    count.decide(&flp_verifier).unwrap()
+                })
+                .collect::<Vec<_>>(),
+        );
 
-            count.decide(&flp_verifier).unwrap()
-        })
-        .collect::<Vec<_>>();
+        start += cfg.batch_size;
+    }
 
     // Tree prune
     let req = ApplyFLPResultsRequest { keep };
-    let response_0 = client_0.apply_flp_results(long_context(), req.clone());
-    let response_1 = client_1.apply_flp_results(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.apply_flp_results(long_context(), req.clone());
+    let resp_1 = client_1.apply_flp_results(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
Expand All @@ -213,10 +228,10 @@ async fn run_level(
is_last,
};

let response_0 = client_0.tree_crawl(long_context(), req.clone());
let response_1 = client_1.tree_crawl(long_context(), req);
let resp_0 = client_0.tree_crawl(long_context(), req.clone());
let resp_1 = client_1.tree_crawl(long_context(), req);
let ((cnt_values_0, mt_root_0, indices_0), (cnt_values_1, mt_root_1, indices_1)) =
try_join!(response_0, response_1).unwrap();
try_join!(resp_0, resp_1).unwrap();

assert_eq!(cnt_values_0.len(), cnt_values_1.len());
keep =
@@ -254,9 +269,9 @@ async fn run_level(
 
     // Tree prune
     let req = TreePruneRequest { keep };
-    let response_0 = client_0.tree_prune(long_context(), req.clone());
-    let response_1 = client_1.tree_prune(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_prune(long_context(), req.clone());
+    let resp_1 = client_1.tree_prune(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     Ok(())
 }
@@ -270,10 +285,9 @@ async fn run_level_last(
     let threshold = core::cmp::max(1, (cfg.threshold * (num_clients as f64)) as u64);
 
     let req = TreeCrawlLastRequest {};
-    let response_0 = client_0.tree_crawl_last(long_context(), req.clone());
-    let response_1 = client_1.tree_crawl_last(long_context(), req);
-    let ((cnt_values_0, hashes_0), (cnt_values_1, hashes_1)) =
-        try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_crawl_last(long_context(), req.clone());
+    let resp_1 = client_1.tree_crawl_last(long_context(), req);
+    let ((cnt_values_0, hashes_0), (cnt_values_1, hashes_1)) = try_join!(resp_0, resp_1).unwrap();
 
     assert_eq!(cnt_values_0.len(), cnt_values_1.len());
     assert_eq!(hashes_0.len(), hashes_1.len());
@@ -289,14 +303,14 @@
 
     // Tree prune
     let req = TreePruneRequest { keep };
-    let response_0 = client_0.tree_prune(long_context(), req.clone());
-    let response_1 = client_1.tree_prune(long_context(), req);
-    try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.tree_prune(long_context(), req.clone());
+    let resp_1 = client_1.tree_prune(long_context(), req);
+    try_join!(resp_0, resp_1).unwrap();
 
     let req = FinalSharesRequest {};
-    let response_0 = client_0.final_shares(long_context(), req.clone());
-    let response_1 = client_1.final_shares(long_context(), req);
-    let (shares_0, shares_1) = try_join!(response_0, response_1).unwrap();
+    let resp_0 = client_0.final_shares(long_context(), req.clone());
+    let resp_1 = client_1.final_shares(long_context(), req);
+    let (shares_0, shares_1) = try_join!(resp_0, resp_1).unwrap();
     for res in &collect::KeyCollection::<Field64>::final_values(&shares_0, &shares_1) {
         let bits = mastic::bits_to_bitstring(&res.path);
         if res.value > Field64::from(0) {
@@ -348,7 +362,7 @@ async fn main() -> io::Result<()> {
     let mut responses = vec![];
 
     for _ in 0..reqs_in_flight {
-        let this_batch = std::cmp::min(left_to_go, cfg.addkey_batch_size);
+        let this_batch = std::cmp::min(left_to_go, cfg.batch_size);
         left_to_go -= this_batch;
 
         if this_batch > 0 {
@@ -371,7 +385,7 @@ async fn main() -> io::Result<()> {
     for level in 0..bit_len - 1 {
         let start_level = Instant::now();
         if level == 0 {
-            run_flp_queries(&client_0, &client_1).await?;
+            run_flp_queries(&cfg, &client_0, &client_1, num_clients).await?;
         }
         run_level(&cfg, &client_0, &client_1, num_clients).await?;
         println!(
7 changes: 5 additions & 2 deletions src/bin/server.rs
@@ -90,11 +90,14 @@ impl Collector for CollectorServer {
     async fn run_flp_queries(
         self,
         _: context::Context,
-        _req: RunFlpQueriesRequest,
+        req: RunFlpQueriesRequest,
     ) -> Vec<Vec<Field64>> {
         let mut coll = self.arc.lock().unwrap();
+        let start = req.start;
+        let end = req.end;
+        assert!(start < end);
 
-        coll.run_flp_queries()
+        coll.run_flp_queries(start, end)
     }
 
     async fn apply_flp_results(self, _: context::Context, req: ApplyFLPResultsRequest) -> String {
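The handler derives the window from the request and rejects empty windows with `assert!(start < end)`; beyond that it trusts the leader's bounds. A slightly more defensive variant (our sketch, not the commit's code) would also clamp the upper bound:

    // Hypothetical defensive variant: clamp `end` rather than trusting the caller.
    fn clamp_window(start: usize, end: usize, num_clients: usize) -> (usize, usize) {
        assert!(start < end, "window must be non-empty");
        (start, end.min(num_clients))
    }

Since collect.rs selects clients by filtering indices rather than slicing (see the next file), an oversized `end` is harmless there, so the single assert is a defensible choice.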
3 changes: 2 additions & 1 deletion src/collect.rs
@@ -137,7 +137,7 @@ where
         child
     }
 
-    pub fn run_flp_queries(&mut self) -> Vec<Vec<T>> {
+    pub fn run_flp_queries(&mut self, start: usize, end: usize) -> Vec<Vec<T>> {
         let level = self.frontier[0].path.len();
         assert_eq!(level, 0);
 
@@ -156,6 +156,7 @@
             .key_values
             .par_iter()
             .enumerate()
+            .filter(|(client_index, _)| *client_index >= start && *client_index < end)
             .map(|(client_index, _)| {
                 let y_p0 = node_left.key_values[client_index];
                 let y_p1 = node_right.key_values[client_index];
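The window is applied as a filter over the enumerated per-client values, so each call verifies only clients with index in `[start, end)`, and rayon's `collect()` preserves the original client order. A standalone illustration with made-up values:

    use rayon::prelude::*; // rayon provides the par_iter used in collect.rs

    fn main() {
        let key_values = vec![10u64, 11, 12, 13, 14, 15];
        let window: Vec<u64> = key_values
            .par_iter()
            .enumerate()
            .filter(|(client_index, _)| *client_index >= 2 && *client_index < 5)
            .map(|(_, v)| *v)
            .collect();
        assert_eq!(window, vec![12, 13, 14]); // order preserved, window [2, 5)
    }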
8 changes: 3 additions & 5 deletions src/config.rs
@@ -5,7 +5,7 @@ use serde_json::Value;
 
 pub struct Config {
     pub data_bytes: usize,
-    pub addkey_batch_size: usize,
+    pub batch_size: usize,
     pub unique_buckets: usize,
     pub threshold: f64,
     pub zipf_exponent: f64,
@@ -22,9 +22,7 @@ pub fn get_config(filename: &str) -> Config {
     let v: Value = serde_json::from_str(json_data).expect("Cannot parse JSON config");
 
     let data_bytes: usize = v["data_bytes"].as_u64().expect("Can't parse data_bytes") as usize;
-    let addkey_batch_size: usize = v["addkey_batch_size"]
-        .as_u64()
-        .expect("Can't parse addkey_batch_size") as usize;
+    let batch_size: usize = v["batch_size"].as_u64().expect("Can't parse batch_size") as usize;
     let unique_buckets: usize = v["unique_buckets"]
         .as_u64()
         .expect("Can't parse unique_buckets") as usize;
@@ -37,7 +35,7 @@
 
     Config {
         data_bytes,
-        addkey_batch_size,
+        batch_size,
         unique_buckets,
         threshold,
         zipf_exponent,
5 changes: 4 additions & 1 deletion src/rpc.rs
@@ -37,7 +37,10 @@ pub struct TreeCrawlRequest {
 pub struct TreeCrawlLastRequest {}
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct RunFlpQueriesRequest {}
+pub struct RunFlpQueriesRequest {
+    pub start: usize,
+    pub end: usize,
+}
 
 #[derive(Clone, Debug, Serialize, Deserialize)]
 pub struct TreePruneRequest {
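`RunFlpQueriesRequest` goes from an empty marker struct to carrying the half-open window, keeping its `Serialize`/`Deserialize` derives so both ends agree on the wire format. Because the windows are half-open, consecutive requests tile the client range without gaps or double counting; a small sketch (the `window_requests` helper and the values are ours, not part of the commit):

    // Sketch: the window requests a leader would build, assuming batch_size > 0
    // (step_by panics on a zero step).
    fn window_requests(num_clients: usize, batch_size: usize) -> Vec<RunFlpQueriesRequest> {
        (0..num_clients)
            .step_by(batch_size)
            .map(|start| RunFlpQueriesRequest {
                start,
                end: std::cmp::min(num_clients, start + batch_size),
            })
            .collect()
    }
    // window_requests(10, 4) yields [0, 4), [4, 8), [8, 10): every client exactly once.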
