Skip to content

Commit

Permalink
feat: keccak + weierstrass trace gen opt (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
ratankaliani authored Feb 29, 2024
1 parent 642f8ee commit c2e012f
Show file tree
Hide file tree
Showing 9 changed files with 243 additions and 182 deletions.
106 changes: 53 additions & 53 deletions core/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,82 +1,82 @@
[package]
edition = "2021"
name = "sp1-core"
version = "0.1.0"
edition = "2021"

[dependencies]
bincode = "1.3.3"
serde = { version = "1.0", features = ["derive"] }
p3-field = { workspace = true }
p3-air = { workspace = true }
p3-matrix = { workspace = true }
p3-baby-bear = { workspace = true }
elf = "0.7.4"
sp1-derive = { path = "../derive" }
p3-commit = { workspace = true }
p3-challenger = { workspace = true }
p3-dft = { workspace = true }
p3-fri = { workspace = true }
p3-goldilocks = { workspace = true }
p3-keccak = { workspace = true }
p3-keccak-air = { workspace = true }
p3-mds = { workspace = true }
p3-merkle-tree = { workspace = true }
p3-poseidon2 = { workspace = true }
p3-blake3 = { workspace = true }
p3-symmetric = { workspace = true }
p3-uni-stark = { workspace = true }
p3-maybe-rayon = { workspace = true }
p3-util = { workspace = true }
itertools = "0.12.0"
rrs-lib = { git = "https://github.com/GregAC/rrs.git" }
lazy_static = "1.4"
log = "0.4.20"
num = { version = "0.4.1" }
nohash-hasher = "0.2.0"
lazy_static = "1.4"
num = {version = "0.4.1"}
p3-air = {workspace = true}
p3-baby-bear = {workspace = true}
p3-blake3 = {workspace = true}
p3-challenger = {workspace = true}
p3-commit = {workspace = true}
p3-dft = {workspace = true}
p3-field = {workspace = true}
p3-fri = {workspace = true}
p3-goldilocks = {workspace = true}
p3-keccak = {workspace = true}
p3-keccak-air = {workspace = true}
p3-matrix = {workspace = true}
p3-maybe-rayon = {workspace = true}
p3-mds = {workspace = true}
p3-merkle-tree = {workspace = true}
p3-poseidon2 = {workspace = true}
p3-symmetric = {workspace = true}
p3-uni-stark = {workspace = true}
p3-util = {workspace = true}
rrs-lib = {git = "https://github.com/GregAC/rrs.git"}
serde = {version = "1.0", features = ["derive"]}
sp1-derive = {path = "../derive"}

tracing = "0.1.40"
tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] }
tracing-forest = { version = "0.1.6", features = ["ansi", "smallvec"] }
tracing-log = "0.2.0"
clap = { version = "4.4.0", features = ["derive"] }
curve25519-dalek = { version = "=4.0.0" }
hex = "0.4.3"
tempfile = "3.9.0"
flate2 = "1.0.28"
size = "0.4.1"
serde_json = { version = "1.0.113", default-features = false, features = [
"alloc",
] }
k256 = { version = "0.13.3", features = ["expose-field"] }
elliptic-curve = "0.13.8"
anyhow = "1.0.79"
serial_test = "3.0.0"
petgraph = "0.6.4"
tiny-keccak = { version = "2.0.2", features = ["keccak"] }
hashbrown = "0.14.3"
num_cpus = "1.16.0"
blake3 = "1.5"
blake3-zkvm = { git = "https://github.com/sp1-patches/BLAKE3.git", branch = "patch-blake3_zkvm/v.1.0.0" }
blake3-zkvm = {git = "https://github.com/sp1-patches/BLAKE3.git", branch = "patch-blake3_zkvm/v.1.0.0"}
cfg-if = "1.0.0"
clap = {version = "4.4.0", features = ["derive"]}
curve25519-dalek = {version = "=4.0.0"}
elliptic-curve = "0.13.8"
flate2 = "1.0.28"
hashbrown = "0.14.3"
hex = "0.4.3"
k256 = {version = "0.13.3", features = ["expose-field"]}
num_cpus = "1.16.0"
petgraph = "0.6.4"
serde_json = {version = "1.0.113", default-features = false, features = [
"alloc",
]}
serial_test = "3.0.0"
size = "0.4.1"
tempfile = "3.9.0"
tiny-keccak = {version = "2.0.2", features = ["keccak"]}
tracing = "0.1.40"
tracing-forest = {version = "0.1.6", features = ["ansi", "smallvec"]}
tracing-log = "0.2.0"
tracing-subscriber = {version = "0.3.17", features = ["std", "env-filter"]}

[dev-dependencies]
criterion = "0.5.1"
num = { version = "0.4.1", features = ["rand"] }
num = {version = "0.4.1", features = ["rand"]}
rand = "0.8.5"


[features]
perf = ["parallel"]
parallel = ["p3-maybe-rayon/parallel", "p3-blake3/parallel"]
default = ["perf"]
debug = ["parallel"]
debug-proof = ["parallel", "perf"]
serial = []
default = ["perf"]
keccak = []
neon = ["p3-blake3/neon"]
parallel = ["p3-maybe-rayon/parallel", "p3-blake3/parallel"]
perf = ["parallel"]
serial = []

[[bench]]
name = "main"
harness = false
name = "main"

[lib]
bench = false
186 changes: 110 additions & 76 deletions core/src/syscall/precompiles/keccak256/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use alloc::vec::Vec;
use p3_field::PrimeField32;
use p3_keccak_air::{generate_trace_rows, NUM_KECCAK_COLS, NUM_ROUNDS};
use p3_matrix::dense::RowMajorMatrix;
use p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice};
use tracing::instrument;

use crate::{
air::MachineAir, runtime::ExecutionRecord, syscall::precompiles::keccak256::STATE_SIZE,
Expand All @@ -20,6 +22,7 @@ impl<F: PrimeField32> MachineAir<F> for KeccakPermuteChip {
"KeccakPermute".to_string()
}

#[instrument(name = "generate KeccakPermute trace", skip_all)]
fn generate_trace(
&self,
input: &ExecutionRecord,
Expand All @@ -39,84 +42,115 @@ impl<F: PrimeField32> MachineAir<F> for KeccakPermuteChip {
num_total_permutations = 1;
}

let mut new_field_events = Vec::new();
let mut rows = Vec::new();
for permutation_num in 0..num_total_permutations {
let is_real_permutation = permutation_num < num_real_permutations;

let event = if is_real_permutation {
Some(&input.keccak_permute_events[permutation_num])
} else {
None
};

let perm_input: [u64; STATE_SIZE] = if is_real_permutation {
event.unwrap().pre_state
} else {
[0; STATE_SIZE]
};

let start_clk = if is_real_permutation {
event.unwrap().clk
} else {
0
};

let shard = if is_real_permutation {
event.unwrap().shard
} else {
0
};

// First get the trace for the plonky3 keccak air.
let p3_keccak_trace = generate_trace_rows::<F>(vec![perm_input]);

// Create all the rows for the permutation.
for (i, p3_keccak_row) in (0..NUM_ROUNDS).zip(p3_keccak_trace.rows()) {
let mut row = [F::zero(); NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS];

// Copy the keccack row into the trace_row
row[..NUM_KECCAK_COLS].copy_from_slice(p3_keccak_row);

let mem_row = &mut row[NUM_KECCAK_COLS..];

let col: &mut KeccakMemCols<F> = mem_row.borrow_mut();
col.shard = F::from_canonical_u32(shard);
col.clk = F::from_canonical_u32(start_clk + i as u32 * 4);

// if this is the first row, then populate read memory accesses
if i == 0 && is_real_permutation {
for (j, read_record) in event.unwrap().state_read_records.iter().enumerate() {
col.state_mem[j].populate_read(*read_record, &mut new_field_events);
}

col.state_addr = F::from_canonical_u32(event.unwrap().state_addr);
col.do_memory_check = F::one();
}

// if this is the last row, then populate write memory accesses
let last_row_num = NUM_ROUNDS - 1;
if i == last_row_num && is_real_permutation {
for (j, write_record) in event.unwrap().state_write_records.iter().enumerate() {
col.state_mem[j].populate_write(*write_record, &mut new_field_events);
}

col.state_addr = F::from_canonical_u32(event.unwrap().state_addr);
col.do_memory_check = F::one();
}

col.is_real = F::from_bool(is_real_permutation);

rows.push(row);

if rows.len() == num_rows {
break;
}
}
let chunk_size = std::cmp::max(num_total_permutations / num_cpus::get(), 1);

// Use par_chunks to generate the trace in parallel.
let rows_and_records = (0..num_total_permutations)
.collect::<Vec<_>>()
.par_chunks(chunk_size)
.map(|chunk| {
let mut record = ExecutionRecord::default();
let mut new_field_events = Vec::new();

let rows = chunk
.iter()
.flat_map(|permutation_num| {
let mut rows = Vec::new();

let is_real_permutation = *permutation_num < num_real_permutations;

let event = if is_real_permutation {
Some(&input.keccak_permute_events[*permutation_num])
} else {
None
};

let perm_input: [u64; STATE_SIZE] = if is_real_permutation {
event.unwrap().pre_state
} else {
[0; STATE_SIZE]
};

let start_clk = if is_real_permutation {
event.unwrap().clk
} else {
0
};

let shard = if is_real_permutation {
event.unwrap().shard
} else {
0
};

// First get the trace for the plonky3 keccak air.
let p3_keccak_trace = generate_trace_rows::<F>(vec![perm_input]);

// Create all the rows for the permutation.
for (i, p3_keccak_row) in (0..NUM_ROUNDS).zip(p3_keccak_trace.rows()) {
let row_idx = permutation_num * NUM_ROUNDS + i;

let mut row = [F::zero(); NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS];

// Copy the keccack row into the trace_row
row[..NUM_KECCAK_COLS].copy_from_slice(p3_keccak_row);

let mem_row = &mut row[NUM_KECCAK_COLS..];

let col: &mut KeccakMemCols<F> = mem_row.borrow_mut();
col.shard = F::from_canonical_u32(shard);
col.clk = F::from_canonical_u32(start_clk + i as u32 * 4);

// if this is the first row, then populate read memory accesses
if i == 0 && is_real_permutation {
for (j, read_record) in
event.unwrap().state_read_records.iter().enumerate()
{
col.state_mem[j]
.populate_read(*read_record, &mut new_field_events);
}

col.state_addr = F::from_canonical_u32(event.unwrap().state_addr);
col.do_memory_check = F::one();
}

// if this is the last row, then populate write memory accesses
let last_row_num = NUM_ROUNDS - 1;
if i == last_row_num && is_real_permutation {
for (j, write_record) in
event.unwrap().state_write_records.iter().enumerate()
{
col.state_mem[j]
.populate_write(*write_record, &mut new_field_events);
}

col.state_addr = F::from_canonical_u32(event.unwrap().state_addr);
col.do_memory_check = F::one();
}

col.is_real = F::from_bool(is_real_permutation);

rows.push(row);

if row_idx == num_rows - 1 {
break;
}
}
rows
})
.collect::<Vec<_>>();
record.add_field_events(&new_field_events);
(rows, record)
})
.collect::<Vec<_>>();

// Generate the trace rows for each event.
let mut rows: Vec<[F; NUM_KECCAK_COLS + NUM_KECCAK_MEM_COLS]> = vec![];
for mut row_and_record in rows_and_records {
rows.extend(row_and_record.0);
output.append(&mut row_and_record.1);
}

output.add_field_events(&new_field_events);

// Convert the trace to a row major matrix.
RowMajorMatrix::new(
rows.into_iter().flatten().collect::<Vec<_>>(),
Expand Down
Loading

0 comments on commit c2e012f

Please sign in to comment.