Skip to content

Commit

Permalink
Replace SHA with BLAKE3
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Dec 7, 2023
1 parent 1cebe78 commit 44331b2
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 121 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ license = "MIT"
[dependencies]
aes = "0.8.1"
bincode = "1.3.3"
blake3 = { version = "1.5.0", features = ["rayon"]}
clap = "2.0"
ctr = "0.9.2"
futures = "0.3.28"
Expand All @@ -28,7 +29,6 @@ rayon = "1.8.0"
rs_merkle = "1.2"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sha2 = "0.10.8"
tarpc = { version = "0.30.0", features = ["full", "serde-transport", "tcp", "tokio1"] }
tokio = { version = "1.32.0", features = ["full", "macros"] }
zipf = "7.0.1"
Expand Down
3 changes: 2 additions & 1 deletion src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use mastic::{
GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest,
TreeCrawlRequest, TreeInitRequest, TreePruneRequest,
},
HASH_SIZE,
};
use prio::field::Field64;
use tarpc::{
Expand Down Expand Up @@ -122,7 +123,7 @@ impl Collector for CollectorServer {
res
}

async fn get_proofs(self, _: context::Context, req: GetProofsRequest) -> Vec<[u8; 32]> {
async fn get_proofs(self, _: context::Context, req: GetProofsRequest) -> Vec<[u8; HASH_SIZE]> {
let coll = self.arc.lock().unwrap();
debug_assert!(req.start < req.end);

Expand Down
134 changes: 60 additions & 74 deletions src/collect.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
use blake3::hash;
use prio::{
flp::{types::Count, Type},
vdaf::xof::{IntoFieldVec, Xof, XofShake128},
};
use rayon::prelude::*;
use rs_merkle::{Hasher, MerkleTree};
use serde::{Deserialize, Serialize};
use sha2::{digest::FixedOutput, Digest, Sha256};

use crate::{dpf, prg, xor_in_place, xor_vec};
use crate::{dpf, prg, xor_in_place, xor_vec, HASH_SIZE};

#[derive(Clone)]
pub struct Sha256Algorithm {}
pub struct HashAlg {}

impl Hasher for Sha256Algorithm {
type Hash = [u8; 32];
impl Hasher for HashAlg {
type Hash = [u8; HASH_SIZE];

fn hash(data: &[u8]) -> [u8; 32] {
let mut hasher = Sha256::new();
hasher.update(data);
<[u8; 32]>::from(hasher.finalize_fixed())
fn hash(data: &[u8]) -> [u8; HASH_SIZE] {
hash(data).as_bytes()[0..HASH_SIZE].try_into().unwrap()
}
}

Expand All @@ -43,7 +41,7 @@ pub struct KeyCollection<T> {
frontier: Vec<TreeNode<T>>,
prev_frontier: Vec<TreeNode<T>>,
count: Count<T>,
final_proofs: Vec<[u8; 32]>,
final_proofs: Vec<[u8; HASH_SIZE]>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -55,12 +53,12 @@ pub struct Result<T> {
impl<T> KeyCollection<T>
where
T: prio::field::FieldElement
+ prio::field::FftFriendlyFieldElement
+ std::fmt::Debug
+ std::cmp::PartialOrd
+ Send
+ Sync
+ prg::FromRng
+ prio::field::FftFriendlyFieldElement
+ 'static,
u64: From<T>,
{
Expand Down Expand Up @@ -120,12 +118,11 @@ where
.unzip();

let mut child_val = T::zero();
for (i, &v) in key_values.iter().enumerate() {
// Add in only live values
if self.keys[i].0 {
child_val.add_assign(v);
}
}
key_values
.iter()
.zip(&self.keys)
.filter(|&(_, key)| key.0)
.for_each(|(&v, _)| child_val.add_assign(v));

let mut child = TreeNode::<T> {
path: parent.path.clone(),
Expand Down Expand Up @@ -224,22 +221,6 @@ where
.map(|node| node.value)
.collect::<Vec<T>>();

// Combine the multiple proofs for each client into a single proof for each client.
let num_clients = next_frontier.get(0).map_or(0, |node| node.key_states.len());
let mut key_proofs: Vec<_> = vec![[0u8; 32]; num_clients];
key_proofs
.par_iter_mut()
.enumerate()
.zip_eq(&self.keys)
.for_each(|((proof_index, proof), key)| {
if key.0 {
// If the client is honest.
for node in next_frontier.iter() {
xor_in_place(proof, &node.key_states[proof_index].proof);
}
}
});

// For all prefixes, compute the checks for each client.
let all_y_checks = self
.frontier
Expand Down Expand Up @@ -290,32 +271,38 @@ where
})
.collect::<Vec<_>>();

// Now, we combine all the checks for each client into a single check for each client.
let key_checks = all_y_checks[0] // parallelize the clients
let combined_hashes = self
.keys
.par_iter()
.enumerate()
.zip_eq(&self.keys)
.map(|((client_index, _), key)| {
let mut hasher = Sha256::new();
if key.0 {
// If the client is honest.
all_y_checks.iter().for_each(|checks_for_prefix| {
let mut bytes = vec![];
checks_for_prefix[client_index].encode(&mut bytes);
hasher.update(bytes);
});
}
hasher.finalize().to_vec()
.filter(|(_, key)| key.0)
.map(|(client_index, _)| {
// Combine the multiple proofs that each client has for each prefix into a single
// proof for each client.
let mut proof = [0u8; HASH_SIZE];
next_frontier.iter().for_each(|node| {
xor_in_place(&mut proof, &node.key_states[client_index].proof);
});

// Combine all the checks that each client has for each prefix into a single check
// for each client.
let mut check = [0u8; 8];
all_y_checks.iter().for_each(|checks_for_prefix| {
xor_in_place(&mut check, &checks_for_prefix[client_index].get_encoded());
});

xor_vec(
&proof,
hash(&check).as_bytes()[0..HASH_SIZE].try_into().unwrap(),
)
.try_into()
.unwrap()
})
.collect::<Vec<_>>();

debug_assert_eq!(key_proofs.len(), key_checks.len());

let combined_hashes = key_proofs
.par_iter()
.zip(key_checks.par_iter())
.map(|(proof, check)| xor_vec(proof, check).try_into().unwrap())
.collect::<Vec<[u8; 32]>>();
.collect::<Vec<[u8; HASH_SIZE]>>();
debug_assert_eq!(
self.keys.iter().filter(|&key| key.0).count(),
combined_hashes.len()
);

// Compute the Merkle tree based on y_checks for each client and the proofs.
// If we are at the last level, we only need to compute the root as the malicious clients
Expand All @@ -330,21 +317,21 @@ where
let mut mtree_roots = vec![];
let mut mtree_indices = vec![];
if split_by == 1 {
let mt = MerkleTree::<Sha256Algorithm>::from_leaves(chunks_list[0]);
let mt = MerkleTree::<HashAlg>::from_leaves(chunks_list[0]);
let root = mt.root().unwrap();
mtree_roots.push(root.to_vec());
mtree_indices.push(0);
} else {
for &i in malicious {
let mt_left = MerkleTree::<Sha256Algorithm>::from_leaves(chunks_list[i * 2]);
let mt_left = MerkleTree::<HashAlg>::from_leaves(chunks_list[i * 2]);
let root_left = mt_left.root().unwrap();
mtree_roots.push(root_left.to_vec());
mtree_indices.push(i * 2);

if i * 2 + 1 >= chunks_list.len() {
continue;
}
let mt_right = MerkleTree::<Sha256Algorithm>::from_leaves(chunks_list[i * 2 + 1]);
let mt_right = MerkleTree::<HashAlg>::from_leaves(chunks_list[i * 2 + 1]);
let root_right = mt_right.root().unwrap();
mtree_roots.push(root_right.to_vec());
mtree_indices.push(i * 2 + 1);
Expand All @@ -370,21 +357,20 @@ where
})
.collect::<Vec<TreeNode<T>>>();

let num_clients = next_frontier.get(0).map_or(0, |node| node.key_states.len());
self.final_proofs = vec![[0u8; 32]; num_clients];
self.final_proofs
.par_iter_mut()
self.final_proofs = self
.keys
.par_iter()
.enumerate()
.zip_eq(&self.keys)
.for_each(|((proof_index, proof), key)| {
if key.0 {
// If the client is honest.
for node in next_frontier.iter() {
xor_in_place(proof, &node.key_states[proof_index].proof);
}
}
});

.filter(|(_, key)| key.0) // If the client is honest.
.map(|(proof_index, _)| {
let mut proof = [0u8; HASH_SIZE];
next_frontier.iter().for_each(|node| {
xor_in_place(&mut proof, &node.key_states[proof_index].proof);
});

proof
})
.collect::<Vec<_>>();
self.frontier = next_frontier;

// These are summed evaluations y for different prefixes.
Expand All @@ -394,7 +380,7 @@ where
.collect::<Vec<T>>()
}

pub fn get_proofs(&self, start: usize, end: usize) -> Vec<[u8; 32]> {
pub fn get_proofs(&self, start: usize, end: usize) -> Vec<[u8; HASH_SIZE]> {
let mut proofs = Vec::new();
if end > start && end <= self.final_proofs.len() {
proofs.extend_from_slice(&self.final_proofs[start..end]);
Expand Down
Loading

0 comments on commit 44331b2

Please sign in to comment.