diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000000..5b3ca8c2a1 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +**/target +Dockerfile* +.dockerignore +.git +.gitignore +examples +tests +sp1up +book diff --git a/.github/workflows/docker-publish-gnark.yml b/.github/workflows/docker-publish-gnark.yml new file mode 100644 index 0000000000..c36a28d302 --- /dev/null +++ b/.github/workflows/docker-publish-gnark.yml @@ -0,0 +1,102 @@ +# Source: https://raw.githubusercontent.com/foundry-rs/foundry/master/.github/workflows/docker-publish.yml +name: docker-gnark + +on: + push: + tags: + - "v*.*.*" + schedule: + - cron: "0 0 * * *" + # Trigger without any parameters a proactive rebuild + workflow_dispatch: + inputs: + tags: + description: "Docker tag to push" + required: true + workflow_call: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: succinctlabs/sp1-gnark + +jobs: + container: + runs-on: ubuntu-latest + # https://docs.github.com/en/actions/reference/authentication-in-a-workflow + permissions: + id-token: write + packages: write + contents: read + timeout-minutes: 120 + steps: + - name: Checkout repository + id: checkout + uses: actions/checkout@v4 + + - name: Install Docker BuildX + uses: docker/setup-buildx-action@v2 + id: buildx + with: + install: true + + # Login against a Docker registry except on PR + # https://github.com/docker/login-action + - name: Log into registry ${{ env.REGISTRY }} + # Ensure this doesn't trigger on PR's + if: github.event_name != 'pull_request' + uses: docker/login-action@v2 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + # Extract metadata (tags, labels) for Docker + # https://github.com/docker/metadata-action + - name: Extract Docker metadata + id: meta + uses: docker/metadata-action@v4 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + + # Creates an additional 'latest' or 'nightly' tag + # If the job is triggered via cron schedule, tag nightly and nightly-{SHA} + # If the job is triggered via workflow dispatch and on a master branch, tag branch and latest + # Otherwise, just tag as the branch name + - name: Finalize Docker Metadata + id: docker_tagging + run: | + if [[ "${{ github.event_name }}" == 'workflow_dispatch' ]]; then + echo "manual trigger from workflow_dispatch, assigning tag ${{ github.event.inputs.tags }}" + echo "docker_tags=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ github.event.inputs.tags }}" >> $GITHUB_OUTPUT + elif [[ "${{ github.event_name }}" == 'schedule' ]]; then + echo "cron trigger, assigning nightly tag" + echo "docker_tags=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:nightly,${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:nightly-${GITHUB_SHA}" >> $GITHUB_OUTPUT + else + echo "Neither scheduled nor manual release from main branch. Just tagging as branch name" + echo "docker_tags=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${GITHUB_REF##*/}" >> $GITHUB_OUTPUT + fi + + # Log docker metadata to explicitly know what is being pushed + - name: Inspect Docker Metadata + run: | + echo "TAGS -> ${{ steps.docker_tagging.outputs.docker_tags }}" + echo "LABELS -> ${{ steps.meta.outputs.labels }}" + + # Build and push Docker image + # https://github.com/docker/build-push-action + # https://github.com/docker/build-push-action/blob/master/docs/advanced/cache.md + - name: Build and push Docker image + uses: docker/build-push-action@v3 + with: + context: . 
+ file: ./Dockerfile.gnark-ffi + platforms: linux/amd64,linux/arm64 + push: true + tags: ${{ steps.docker_tagging.outputs.docker_tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max + build-args: | + BUILDTIME=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.created'] }} + VERSION=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.version'] }} + REVISION=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cfc5854ae2..a2547409d0 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -14,37 +14,58 @@ on: - ".github/workflows/**" concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: - plonk: - name: Plonk - runs-on: runs-on,cpu=64,ram=256,family=m7i+m7a,hdd=80,image=ubuntu22-full-x64 + plonk: + name: Plonk Native + runs-on: runs-on,cpu=64,ram=256,family=m7i+m7a,hdd=80,image=ubuntu22-full-x64 + env: + CARGO_NET_GIT_FETCH_WITH_CLI: "true" + steps: + - name: Checkout sources + uses: actions/checkout@v4 + + - name: Setup CI + uses: ./.github/actions/setup + + - name: Run cargo test + uses: actions-rs/cargo@v1 + with: + command: test + toolchain: nightly-2024-04-17 + args: --release -p sp1-sdk --features native-gnark -- test_e2e_prove_plonk --nocapture + env: + RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y -Cdebuginfo=0 -C target-cpu=native + RUST_BACKTRACE: 1 + plonk-docker: + name: Plonk Docker + runs-on: runs-on,cpu=64,ram=256,family=m7i+m7a,hdd=80,image=ubuntu22-full-x64 + env: + CARGO_NET_GIT_FETCH_WITH_CLI: "true" + steps: + - name: Checkout sources + uses: actions/checkout@v4 + + - name: Setup CI + uses: ./.github/actions/setup + + - name: Run cargo test + uses: actions-rs/cargo@v1 + with: + command: test + toolchain: nightly-2024-04-17 + args: --release -p sp1-sdk -- test_e2e_prove_plonk --nocapture env: - CARGO_NET_GIT_FETCH_WITH_CLI: "true" - steps: - - name: Checkout sources - uses: actions/checkout@v4 - - - name: Setup CI - uses: ./.github/actions/setup - - - name: Run cargo test - uses: actions-rs/cargo@v1 - with: - command: test - toolchain: nightly-2024-04-17 - args: --release -p sp1-sdk --features plonk -- test_e2e_prove_plonk --nocapture - env: - RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y -Cdebuginfo=0 -C target-cpu=native - RUST_BACKTRACE: 1 - check-branch: - name: Check branch - runs-on: ubuntu-latest - steps: - - name: Check branch - if: github.head_ref != 'dev' - run: | - echo "ERROR: You can only merge to main from dev." - exit 1 \ No newline at end of file + RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y -Cdebuginfo=0 -C target-cpu=native + RUST_BACKTRACE: 1 + check-branch: + name: Check branch + runs-on: ubuntu-latest + steps: + - name: Check branch + if: github.head_ref != 'dev' + run: | + echo "ERROR: You can only merge to main from dev." 
+ exit 1 diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 120f0e981d..ac6a248770 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -48,7 +48,7 @@ jobs: with: command: test toolchain: nightly-2024-04-17 - args: --release --features plonk + args: --release --features native-gnark env: RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y -Cdebuginfo=0 -C target-cpu=native RUST_BACKTRACE: 1 @@ -79,13 +79,38 @@ jobs: with: command: test toolchain: nightly-2024-04-17 - args: --release --features plonk + args: --release --features native-gnark env: RUSTFLAGS: -Copt-level=3 -Cdebug-assertions -Coverflow-checks=y -Cdebuginfo=0 -C target-cpu=native RUST_BACKTRACE: 1 FRI_QUERIES: 1 SP1_DEV: 1 + test-docker: + name: Test Docker + runs-on: runs-on,runner=64cpu-linux-arm64 + env: + CARGO_NET_GIT_FETCH_WITH_CLI: "true" + steps: + - name: Checkout sources + uses: actions/checkout@v4 + + - name: Setup CI + uses: ./.github/actions/setup + + - name: Build docker image + run: | + docker build -t sp1-gnark -f ./Dockerfile.gnark-ffi . + + - name: Run cargo test + uses: actions-rs/cargo@v1 + env: + SP1_GNARK_IMAGE: sp1-gnark + with: + command: test + toolchain: nightly-2024-04-17 + args: --release -p sp1-prover -- --exact tests::test_e2e + lint: name: Formatting & Clippy runs-on: runs-on,runner=8cpu-linux-x64 @@ -178,5 +203,6 @@ jobs: cargo add sp1-zkvm --path $GITHUB_WORKSPACE/zkvm/entrypoint cargo prove build cd ../script + cargo remove sp1-sdk cargo add sp1-sdk --path $GITHUB_WORKSPACE/sdk SP1_DEV=1 RUST_LOG=info cargo run --release diff --git a/.gitignore b/.gitignore index aab0a12de6..c50ea62a6e 100644 --- a/.gitignore +++ b/.gitignore @@ -12,8 +12,8 @@ pgo-data.profdata .DS_Store # Proofs -**/proof-with-pis.json -**/proof-with-io.json +**/proof-with-pis.bin +**/proof-with-io.bin # Benchmark benchmark.csv diff --git a/.vscode/settings.json b/.vscode/settings.json index 51f6473517..eb7b799168 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -40,7 +40,6 @@ // "examples/tendermint/script/Cargo.toml", // // Tests. 
// "examples/ed25519/Cargo.toml", - // "tests/blake3-compress/Cargo.toml", // "tests/cycle-tracker/Cargo.toml", // "tests/ecrecover/Cargo.toml", // "tests/ed-add/Cargo.toml", diff --git a/Cargo.lock b/Cargo.lock index b75a8e5669..adfcf808b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -78,9 +78,9 @@ checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" [[package]] name = "alloy-primitives" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db8aa973e647ec336810a9356af8aea787249c9d00b1525359f3db29a68d231b" +checksum = "f783611babedbbe90db3478c120fb5f5daacceffc210b39adc0af4fe0da70bad" dependencies = [ "alloy-rlp", "bytes", @@ -110,9 +110,9 @@ dependencies = [ [[package]] name = "alloy-sol-macro" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dbd17d67f3e89478c8a634416358e539e577899666c927bc3d2b1328ee9b6ca" +checksum = "4bad41a7c19498e3f6079f7744656328699f8ea3e783bdd10d85788cd439f572" dependencies = [ "alloy-sol-macro-expander", "alloy-sol-macro-input", @@ -124,13 +124,13 @@ dependencies = [ [[package]] name = "alloy-sol-macro-expander" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6da95adcf4760bb4b108fefa51d50096c5e5fdd29ee72fed3e86ee414f2e34" +checksum = "fd9899da7d011b4fe4c406a524ed3e3f963797dbc93b45479d60341d3a27b252" dependencies = [ "alloy-sol-macro-input", "const-hex", - "heck 0.4.1", + "heck", "indexmap 2.2.6", "proc-macro-error", "proc-macro2", @@ -142,13 +142,13 @@ dependencies = [ [[package]] name = "alloy-sol-macro-input" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32c8da04c1343871fb6ce5a489218f9c85323c8340a36e9106b5fc98d4dd59d5" +checksum = "d32d595768fdc61331a132b6f65db41afae41b9b97d36c21eb1b955c422a7e60" dependencies = [ "const-hex", "dunce", - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.66", @@ -157,9 +157,9 @@ dependencies = [ [[package]] name = "alloy-sol-types" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40a64d2d2395c1ac636b62419a7b17ec39031d6b2367e66e9acbf566e6055e9c" +checksum = "a49042c6d3b66a9fe6b2b5a8bf0d39fc2ae1ee0310a2a26ffedd79fb097878dd" dependencies = [ "alloy-primitives", "alloy-sol-macro", @@ -561,7 +561,7 @@ dependencies = [ "bitflags 2.5.0", "cexpr", "clang-sys", - "itertools 0.12.1", + "itertools 0.10.5", "lazy_static", "lazycell", "log", @@ -692,6 +692,12 @@ version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" +[[package]] +name = "bytemuck" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" + [[package]] name = "byteorder" version = "1.5.0" @@ -747,9 +753,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.98" +version = "1.0.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41c270e7540d725e65ac7f1b212ac8ce349719624d7bcff99f8e2e488e8cf03f" +checksum = "96c51067fd44124faa7f870b4b1c969379ad32b2ba805aa959430ceaa384f695" dependencies = [ "jobserver", "libc", @@ -834,9 +840,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.4" +version = "4.5.7" source 
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" dependencies = [ "clap_builder", "clap_derive", @@ -844,9 +850,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.2" +version = "4.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +checksum = "f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" dependencies = [ "anstream", "anstyle", @@ -856,11 +862,11 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.4" +version = "4.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" +checksum = "c780290ccf4fb26629baa7a1081e68ced113f1d3ec302fa5948f1c381ebf06c6" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.66", @@ -2083,12 +2089,6 @@ dependencies = [ "fxhash", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -2488,6 +2488,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -3594,9 +3603,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.84" +version = "1.0.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec96c6a92621310b51366f1e28d05ef11489516e93be030060e5fc12024a49d6" +checksum = "22244ce15aa966053a896d1accb3a6e68469b97c7f33f284b99f0d576879fc23" dependencies = [ "unicode-ident", ] @@ -3638,7 +3647,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81bddcdb20abf9501610992b6759a4c888aef7d1a7247ef75e2404275ac24af1" dependencies = [ "anyhow", - "itertools 0.12.1", + "itertools 0.10.5", "proc-macro2", "quote", "syn 2.0.66", @@ -3985,9 +3994,9 @@ dependencies = [ [[package]] name = "ruint" -version = "1.12.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f308135fef9fc398342da5472ce7c484529df23743fb7c734e0f3d472971e62" +checksum = "2c3cc4c2511671f327125da14133d0c5c5d137f006a1017a16f557bc85b16286" dependencies = [ "alloy-rlp", "ark-ff 0.3.0", @@ -4009,9 +4018,9 @@ dependencies = [ [[package]] name = "ruint-macro" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f86854cf50259291520509879a5c294c3c9a4c334e9ff65071c51e42ef1e2343" +checksum = "48fd7bd8a6377e15ad9d42a8ec25371b94ddc67abe7c8b9127bec79bebaaae18" [[package]] name = "rustc-demangle" @@ -4598,6 +4607,7 @@ dependencies = [ "arrayref", "bincode", "blake3", + "bytemuck", "cfg-if", "criterion", "curve25519-dalek", @@ -4605,7 +4615,7 @@ dependencies = [ "elliptic-curve", "generic-array 1.0.0", "hex", - "itertools 0.12.1", + "itertools 0.13.0", "k256", "log", "nohash-hasher", @@ -4699,7 +4709,7 @@ dependencies = [ name = "sp1-primitives" version = "0.1.0" dependencies = [ - "itertools 0.12.1", + "itertools 0.13.0", "lazy_static", "p3-baby-bear", "p3-field", @@ -4714,18 
+4724,21 @@ dependencies = [ "anyhow", "backtrace", "bincode", + "bytemuck", "clap", "dirs", "futures", "hex", "indicatif", - "itertools 0.12.1", + "itertools 0.13.0", "num-bigint 0.4.5", "p3-baby-bear", "p3-bn254-fr", "p3-challenger", "p3-commit", "p3-field", + "p3-util", + "rand", "rayon", "reqwest 0.12.4", "serde", @@ -4754,7 +4767,7 @@ version = "0.1.0" dependencies = [ "bincode", "ff 0.13.0", - "itertools 0.12.1", + "itertools 0.13.0", "p3-air", "p3-baby-bear", "p3-bn254-fr", @@ -4784,7 +4797,7 @@ name = "sp1-recursion-compiler" version = "0.1.0" dependencies = [ "backtrace", - "itertools 0.12.1", + "itertools 0.13.0", "p3-air", "p3-baby-bear", "p3-bn254-fr", @@ -4814,7 +4827,7 @@ dependencies = [ "backtrace", "ff 0.13.0", "hashbrown 0.14.5", - "itertools 0.12.1", + "itertools 0.13.0", "p3-air", "p3-baby-bear", "p3-bn254-fr", @@ -4828,6 +4841,7 @@ dependencies = [ "p3-merkle-tree", "p3-poseidon2", "p3-symmetric", + "p3-util", "rand", "serde", "serde_with", @@ -4852,16 +4866,22 @@ dependencies = [ name = "sp1-recursion-gnark-ffi" version = "0.1.0" dependencies = [ + "anyhow", + "bincode", "bindgen", "cc", "cfg-if", + "hex", "log", "num-bigint 0.4.5", "p3-baby-bear", "p3-field", + "p3-symmetric", "rand", "serde", "serde_json", + "sha2", + "sp1-core", "sp1-recursion-compiler", "tempfile", ] @@ -4870,7 +4890,7 @@ dependencies = [ name = "sp1-recursion-program" version = "0.1.0" dependencies = [ - "itertools 0.12.1", + "itertools 0.13.0", "p3-air", "p3-baby-bear", "p3-challenger", @@ -4923,6 +4943,7 @@ dependencies = [ "strum", "strum_macros", "tempfile", + "thiserror", "tokio", "tracing", "twirp", @@ -4937,6 +4958,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "p3-baby-bear", @@ -4993,11 +5015,11 @@ dependencies = [ [[package]] name = "strum_macros" -version = "0.26.2" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6cf59daf282c0a494ba14fd21610a0325f9f90ec9d1231dea26bcb1d696c946" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.4.1", + "heck", "proc-macro2", "quote", "rustversion", @@ -5043,9 +5065,9 @@ dependencies = [ [[package]] name = "syn-solidity" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8db114c44cf843a8bacd37a146e37987a0b823a0e8bc4fdc610c9c72ab397a5" +checksum = "8d71e19bca02c807c9faa67b5a47673ff231b6e7449b251695188522f1dc44b2" dependencies = [ "paste", "proc-macro2", @@ -5209,9 +5231,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.37.0" +version = "1.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" dependencies = [ "backtrace", "bytes", @@ -5228,9 +5250,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", diff --git a/Dockerfile.gnark-ffi b/Dockerfile.gnark-ffi new file mode 100644 index 0000000000..3e0624d794 --- /dev/null +++ b/Dockerfile.gnark-ffi @@ -0,0 +1,33 @@ +FROM golang:1.22 AS go-builder + +FROM 
rustlang/rust:nightly-bullseye-slim AS rust-builder + +# Dependencies +RUN apt update && apt install -y clang + +# Install Go 1.22 +COPY --from=go-builder /usr/local/go /usr/local/go +ENV PATH="/usr/local/go/bin:$PATH" + +WORKDIR /sp1 + +# Install Rust toolchain +COPY ./rust-toolchain /sp1/rust-toolchain +RUN rustup show + +# Copy repo +COPY . /sp1 + +# Build the gnark-ffi CLI +WORKDIR /sp1/recursion/gnark-cli + +RUN \ + --mount=type=cache,target=target \ + cargo build --release && cp target/release/sp1-recursion-gnark-cli /gnark-cli + +FROM rustlang/rust:nightly-bullseye-slim +COPY --from=rust-builder /gnark-cli /gnark-cli + +LABEL org.opencontainers.image.source=https://github.com/succinctlabs/sp1 + +ENTRYPOINT ["/gnark-cli"] \ No newline at end of file diff --git a/book/SUMMARY.md b/book/SUMMARY.md index 3728ed9b0e..7e41b5b30b 100644 --- a/book/SUMMARY.md +++ b/book/SUMMARY.md @@ -32,14 +32,18 @@ - [Advanced](./generating-proofs/advanced.md) +# Prover Network + +- [Setup](./prover-network/setup.md) + +- [Usage](./prover-network/usage.md) + # Verifying Proofs - [Solidity & EVM](./verifying-proofs/solidity-and-evm.md) # Developers -- [Recommended Settings](./developers/recommended-settings.md) - - [Building Plonk Bn254 Artifacts](./developers/building-plonk-artifacts.md) -- [Common Issues](./developers/common-issues.md) \ No newline at end of file +- [Common Issues](./developers/common-issues.md) diff --git a/book/developers/recommended-settings.md b/book/developers/recommended-settings.md deleted file mode 100644 index 4c03948685..0000000000 --- a/book/developers/recommended-settings.md +++ /dev/null @@ -1,6 +0,0 @@ -# Recommended Settings - -For developers contributing to the SP1 project, we recommend the following settings: - -- `FRI_QUERIES=1`: Makes the prover use less bits of security to generate proofs more quickly. -- `SP1_DEV=1`: This will rebuild the Plonk Bn254 artifacts everytime they are necessary. \ No newline at end of file diff --git a/book/generating-proofs/network.md b/book/generating-proofs/network.md deleted file mode 100644 index 5f1e8d2c57..0000000000 --- a/book/generating-proofs/network.md +++ /dev/null @@ -1,47 +0,0 @@ -# Generating Proofs: Prover Network - -In the case that you do not want to prove locally, you can use the Succinct prover network to generate proofs. - -**Note:** The network is still in development and should be only used for testing purposes. - -## Sending a proof request - -To use the prover network to generate a proof, you can run your program as you would normally but with additional environment variables set: - -```sh -SP1_PROVER=network SP1_PRIVATE_KEY=... cargo run --release -``` - -- `SP1_PROVER` should be set to `network` when using the prover network. - -- `SP1_PRIVATE_KEY` is your secp256k1 private key for signing messages on the network. The balance of - the address corresponding to this private key will be used to pay for the proof request. - -Once a request is sent, a prover will claim the request and start generating a proof. After some -time, it will be returned. - -## Network balance - -Before sending requests, you must ensure you have enough balance on the network. You can add to your -balance by sending ETH to the canonical `NetworkFeeVault` contract on Base, which has the address -[0x66ea36fDBdDD09E3aCAB7B9f654220B00e537574](https://basescan.org/address/0x66ea36fdbddd09e3acab7b9f654220b00e537574#code). 
- -Adding to your balance can be done in [Etherscan](https://basescan.org/address/0x66ea36fdbddd09e3acab7b9f654220b00e537574#writeContract) by -connecting your wallet, or by using the [cast](https://book.getfoundry.sh/cast/) CLI tool. - -This can be done either by calling the `addBalance()` function: - -```sh -# The sender will send 1000 wei and the $OWNER will have their balance increased by 1000 -OWNER=(your address) -AMOUNT=1000 -cast send 0x66ea36fDBdDD09E3aCAB7B9f654220B00e537574 "addBalance(address)" $OWNER --value $AMOUNT --private-key $PRIVATE_KEY --chain-id 8453 --rpc-url https://developer-access-mainnet.base.org -``` - -or by sending ETH directly: - -```sh -# The sender will send 1000 wei and have their balance increased by 1000 -AMOUNT=1000 -cast send 0x66ea36fDBdDD09E3aCAB7B9f654220B00e537574 --value $AMOUNT --private-key $PRIVATE_KEY --chain-id 8453 --rpc-url https://developer-access-mainnet.base.org -``` diff --git a/book/getting-started/install.md b/book/getting-started/install.md index abd3b34a7f..f555f422bc 100644 --- a/book/getting-started/install.md +++ b/book/getting-started/install.md @@ -6,7 +6,8 @@ build the toolchain and CLI from source. ## Requirements - [Rust (Nightly)](https://www.rust-lang.org/tools/install) -- [Go >1.22.1](https://go.dev/doc/install) +- [Docker](https://docs.docker.com/get-docker/) +- [Go >1.22.1 (Optional)](https://go.dev/doc/install) ## Option 1: Prebuilt Binaries (Recommended) @@ -18,7 +19,7 @@ sp1up is the SP1 toolchain installer. Open your terminal and run the following c curl -L https://sp1.succinct.xyz | bash ``` -This will install sp1up, then simply follow the instructions on-screen, which will make the `sp1up` command available in your CLI. +Then simply follow the instructions on-screen, which will make the `sp1up` command available in your CLI. After following the instructions, you can run `sp1up` to install the toolchain: @@ -26,8 +27,10 @@ After following the instructions, you can run `sp1up` to install the toolchain: sp1up ``` -This will install support for the `riscv32im-succinct-zkvm-elf` compilation target within your Rust compiler -and a `cargo prove` CLI tool that will let you compile provable programs and then prove their correctness. +This will install two things: + +1. The `succinct` Rust toolchain which has support for the `riscv32im-succinct-zkvm-elf` compilation target. +2. `cargo prove` CLI tool that will let you compile provable programs and then prove their correctness. You can verify the installation by running `cargo prove --version`: diff --git a/book/getting-started/quickstart.md b/book/getting-started/quickstart.md index f5cdbbb7ad..86a2de7143 100644 --- a/book/getting-started/quickstart.md +++ b/book/getting-started/quickstart.md @@ -59,7 +59,7 @@ b: 332825110087067562321196029789634457848 successfully generated and verified proof for the program! ``` -The program by default is quite small, so proof generation will only take a few seconds locally. After it completes, the proof will be saved in the `proof-with-io.json` file and also be verified for correctness. +The program by default is quite small, so proof generation will only take a few seconds locally. After it completes, the proof will be saved in the `proof-with-io.bin` file and also be verified for correctness. 
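For readers who want to see what that end-to-end flow looks like in code, here is a minimal sketch adapted from the template script that this diff removes from `cli/src/assets/script/main.rs`, with the output path updated to the new `.bin` convention; the input value and the relative ELF path are those of the template, not requirements.

```rust
use sp1_sdk::{utils, ProverClient, SP1Stdin};

/// The ELF produced by `cargo prove build` for the program crate.
const ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf");

fn main() {
    // Setup logging.
    utils::setup_logger();

    // Write the program input.
    let mut stdin = SP1Stdin::new();
    stdin.write(&186u32);

    // Generate the proof, verify it, and save it to disk.
    let client = ProverClient::new();
    let (pk, vk) = client.setup(ELF);
    let proof = client.prove_compressed(&pk, stdin).expect("proving failed");
    client
        .verify_compressed(&proof, &vk)
        .expect("verification failed");
    proof
        .save("proof-with-io.bin")
        .expect("saving proof failed");

    println!("successfully generated and verified proof for the program!");
}
```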
## Modifying the Program diff --git a/book/prover-network/explorer.png b/book/prover-network/explorer.png new file mode 100644 index 0000000000..9241510cd8 Binary files /dev/null and b/book/prover-network/explorer.png differ diff --git a/book/prover-network/setup.md b/book/prover-network/setup.md new file mode 100644 index 0000000000..44544cf689 --- /dev/null +++ b/book/prover-network/setup.md @@ -0,0 +1,35 @@ +# Prover Network: Setup + +So far we've explored how to generate proofs locally, but this can actually be inconvenient on local machines due to high memory / CPU requirements, especially for very large programs. + +Succinct [has been building](https://blog.succinct.xyz/succinct-network/) the Succinct Network, a distributed network of provers that can generate proofs of any size quickly and reliably. It's currently in private beta, but you can get access by following the steps below. + +## Get access + +Currently the network is permissioned, so you need to gain access through Succinct. After you have completed the key setup below, you can submit your address in this [form](https://docs.google.com/forms/d/e/1FAIpQLSd-X9uH7G0bvXH_kjptnQtNil8L4dumrVPpFE4t8Ci1XT1GaQ/viewform) and we'll contact you shortly. + +### Key Setup + +The prover network uses secp256k1 keypairs for authentication, like Ethereum wallets. You may generate a new keypair explicitly for use with the prover network, or use an existing keypair. Currently you do not need to hold any funds in this account, it is used solely for access control. + +Prover network keypair credentials can be generated using the [cast](https://book.getfoundry.sh/cast/) CLI tool: + +[Install](https://book.getfoundry.sh/getting-started/installation#using-foundryup): + +```sh +curl -L https://foundry.paradigm.xyz | bash +``` + +Generate a new keypair: + +```sh +cast wallet new +``` + +Or, retrieve your address from an existing key: + +```sh +cast wallet address --private-key $PRIVATE_KEY +``` + +Make sure to keep your private key somewhere safe and secure, you'll need it to interact with the prover network. diff --git a/book/prover-network/usage.md b/book/prover-network/usage.md new file mode 100644 index 0000000000..edaf114944 --- /dev/null +++ b/book/prover-network/usage.md @@ -0,0 +1,63 @@ +# Prover Network: Usage + +## Sending a proof request + +To use the prover network to generate a proof, you can run your script that uses `sp1_sdk::ProverClient` as you would normally but with additional environment variables set: + +```rust,noplayground +// Generate the proof for the given program. +let client = ProverClient::new(); +let (pk, vk) = client.setup(ELF); +let mut proof = client.prove(&pk, stdin).unwrap(); +``` + +```sh +SP1_PROVER=network SP1_PRIVATE_KEY=... RUST_LOG=info cargo run --release +``` + +- `SP1_PROVER` should be set to `network` when using the prover network. + +- `SP1_PRIVATE_KEY` should be set to your [private key](#key-setup). You will need + to be using a [permissioned](#get-access) key to use the network. + +When you call any of the prove functions in ProverClient now, it will first simulate your program, then wait for it to be proven through the network and finally return the proof. + +## View the status of your proof + +You can view your proof and other running proofs on the [explorer](https://explorer.succinct.xyz/). The page for your proof will show details such as the stage of your proof and the cycles used. It also shows the program hash which is the keccak256 of the program bytes. 
+ +![Screenshot from explorer.succinct.xyz showing the details of a proof including status, stage, type, program, requester, prover, CPU cycles used, time requested, and time claimed.](explorer.png) + +## Advanced Usage + +### Skip simulation + +To skip the simulation step and directly submit the program for proof generation, you can set the `SKIP_SIMULATION` environment variable to `true`. This will save some time if you are sure that your program is correct. If your program panics, the proof will fail and ProverClient will panic. + +### Use NetworkProver directly + +By using the `sp1_sdk::NetworkProver` struct directly, you can call async functions directly and have programmatic access to the proof ID. + +```rust,noplayground +impl NetworkProver { + /// Creates a new [NetworkProver] with the private key set in `SP1_PRIVATE_KEY`. + pub fn new() -> Self; + + /// Creates a new [NetworkProver] with the given private key. + pub fn new_from_key(private_key: &str) -> Self; + + /// Requests a proof from the prover network, returning the proof ID. + pub async fn request_proof( + &self, + elf: &[u8], + stdin: SP1Stdin, + mode: ProofMode, + ) -> Result; + + /// Waits for a proof to be generated and returns the proof. + pub async fn wait_proof(&self, proof_id: &str) -> Result
; + + /// Requests a proof from the prover network and waits for it to be generated. + pub async fn prove(&self, elf: &[u8], stdin: SP1Stdin) -> Result
; +} +``` diff --git a/book/verifying-proofs/solidity-and-evm.md b/book/verifying-proofs/solidity-and-evm.md index e75c988f14..e282a73cff 100644 --- a/book/verifying-proofs/solidity-and-evm.md +++ b/book/verifying-proofs/solidity-and-evm.md @@ -7,12 +7,7 @@ of using SP1 for on-chain usecases, refer to the [SP1 Project Template](https:// By default, the proofs generated by SP1 are not verifiable onchain, as they are non-constant size and STARK verification on Ethereum is very expensive. To generate a proof that can be verified onchain, we use performant STARK recursion to combine SP1 shard proofs into a single STARK proof and then wrap that in a SNARK proof. Our `ProverClient` has a function for this called `prove_plonk`. Behind the scenes, this function will first generate a normal SP1 proof, then recursively combine all of them into a single proof using the STARK recursion protocol. Finally, the proof is wrapped in a SNARK proof using PLONK. -**The PLONK Bn254 prover is only guaranteed to work on official releases of SP1.** - -To use PLONK proving & verification locally, enable the `plonk` feature flag in the sp1-sdk and ensure that Go >1.22.1 is installed. -```toml -sp1-sdk = { features = ["plonk"] } -``` +**The PLONK Bn254 prover is only guaranteed to work on official releases of SP1. To use PLONK proving & verification locally, ensure that you have Docker installed.** ### Example @@ -22,16 +17,23 @@ sp1-sdk = { features = ["plonk"] } You can run the above script with `RUST_LOG=info cargo run --bin plonk_bn254 --release` in `examples/fibonacci/script`. -## Install SP1 Contracts +### Advanced: PLONK without Docker -# SP1 Contracts +If you would like to run the PLONK prover directly without Docker, you must have Go 1.22 installed and enable the `native-plonk` feature in `sp1-sdk`. This path is not recommended and may require additional native dependencies. + +```toml +sp1-sdk = { features = ["native-plonk"] } +``` + +# Install SP1 Contracts + +## SP1 Contracts This repository contains the smart contracts for verifying [SP1](https://github.com/succinctlabs/sp1) EVM proofs. ## Installation -> [!WARNING] -> [Foundry](https://github.com/foundry-rs/foundry) installs the latest release version initially, but subsequent `forge update` commands will use the `main` branch. This branch is the development branch and should be avoided in favor of tagged releases. The release process matches a specific SP1 version. +> [!WARNING] > [Foundry](https://github.com/foundry-rs/foundry) installs the latest release version initially, but subsequent `forge update` commands will use the `main` branch. This branch is the development branch and should be avoided in favor of tagged releases. The release process matches a specific SP1 version. To install the latest release version: @@ -40,6 +42,7 @@ forge install succinctlabs/sp1-contracts ``` To install a specific version: + ```bash forge install succinctlabs/sp1-contracts@ ``` @@ -51,7 +54,7 @@ Add `@sp1-contracts/=lib/sp1-contracts/contracts/src/` in `remappings.txt.` Once installed, you can use the contracts in the library by importing them: ```solidity -pragma solidity ^0.8.25; +pragma solidity ^0.8.19; import {SP1Verifier} from "@sp1-contracts/SP1Verifier.sol"; @@ -59,4 +62,4 @@ contract MyContract is SP1Verifier { } ``` -For more details on the contracts, refer to the [sp1-contracts](https://github.com/succinctlabs/sp1-contracts) repo. 
\ No newline at end of file +For more details on the contracts, refer to the [sp1-contracts](https://github.com/succinctlabs/sp1-contracts) repo. diff --git a/book/writing-programs/setup.md b/book/writing-programs/setup.md index ab87e06044..6cca74e9db 100644 --- a/book/writing-programs/setup.md +++ b/book/writing-programs/setup.md @@ -11,7 +11,7 @@ cargo prove new cd program ``` -#### Build +### Build To build the program, simply run: @@ -21,6 +21,13 @@ cargo prove build This will compile the ELF that can be executed in the zkVM and put the executable in `elf/riscv32im-succinct-zkvm-elf`. +### Build with Docker + +Another option is to build your program in a Docker container. This is useful if you are on a platform that does not have prebuilt binaries for the succinct toolchain, or if you are looking to get a reproducible ELF output. To do so, just use the `--docker` flag. + +``` +cargo prove build --docker +``` ## Manual @@ -31,7 +38,7 @@ cargo new program cd program ``` -#### Cargo Manifest +### Cargo Manifest Inside this crate, add the `sp1-zkvm` crate as a dependency. Your `Cargo.toml` should look like as follows: @@ -49,7 +56,7 @@ sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git" } The `sp1-zkvm` crate includes necessary utilities for your program, including handling inputs and outputs, precompiles, patches, and more. -#### main.rs +### main.rs Inside the `src/main.rs` file, you must make sure to include these two lines to ensure that the crate properly compiles. @@ -61,8 +68,7 @@ sp1_zkvm::entrypoint!(main); These two lines of code wrap your main function with some additional logic to ensure that your program compiles correctly with the RISCV target. - -#### Build +### Build To build the program, simply run: @@ -70,4 +76,4 @@ To build the program, simply run: cargo prove build ``` -This will compile the ELF (RISCV binary) that can be executed in the zkVM and put the executable in `elf/riscv32im-succinct-zkvm-elf`. \ No newline at end of file +This will compile the ELF (RISCV binary) that can be executed in the zkVM and put the executable in `elf/riscv32im-succinct-zkvm-elf`. 
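Putting the manual setup together, a complete provable program is only a few lines. The sketch below is modeled on the Fibonacci template that this diff deletes from `cli/src/assets/program/main.rs`; the input type and the computation are illustrative.

```rust
//! A minimal program to be proven inside the zkVM.

#![no_main]
sp1_zkvm::entrypoint!(main);

pub fn main() {
    // Read a private input from the prover.
    let n = sp1_zkvm::io::read::<u32>();

    // Do some computation inside the zkVM (here, a running sum).
    let mut sum: u64 = 0;
    for i in 0..u64::from(n) {
        sum += i;
    }

    // Commit the result as a public output of the proof.
    sp1_zkvm::io::commit(&sum);
}
```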
diff --git a/cli/Cargo.toml b/cli/Cargo.toml index 6831879e6e..12583a5141 100644 --- a/cli/Cargo.toml +++ b/cli/Cargo.toml @@ -13,7 +13,7 @@ vergen = { version = "8", default-features = false, features = [ [dependencies] anyhow = { version = "1.0.83", features = ["backtrace"] } cargo_metadata = "0.18.1" -clap = { version = "4.5.4", features = ["derive", "env"] } +clap = { version = "4.5.7", features = ["derive", "env"] } sp1-prover = { path = "../prover" } sp1-sdk = { path = "../sdk" } sp1-core = { path = "../core" } diff --git a/cli/docker/Dockerfile b/cli/docker/Dockerfile index 1f0258599f..b8b35b3467 100644 --- a/cli/docker/Dockerfile +++ b/cli/docker/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:24.04@sha256:3f85b7caad41a95462cf5b787d8a04604c8262cdcdf9a472b8c52ef83375fe15 +FROM ubuntu:24.04@sha256:e3f92abc0967a6c19d0dfa2d55838833e947b9d74edbcb0113e48535ad4be12a RUN apt-get update RUN apt-get install -y --no-install-recommends ca-certificates clang curl libssl-dev pkg-config git dialog diff --git a/cli/src/assets/.gitignore b/cli/src/assets/.gitignore deleted file mode 100644 index 0be6e7269a..0000000000 --- a/cli/src/assets/.gitignore +++ /dev/null @@ -1,16 +0,0 @@ -# Cargo build -**/target - -# Cargo config -.cargo - -# Profile-guided optimization -/tmp -pgo-data.profdata - -# MacOS nuisances -.DS_Store - -# Proofs -**/proof-with-pis.json -**/proof-with-io.json diff --git a/cli/src/assets/.vscode/settings.json b/cli/src/assets/.vscode/settings.json deleted file mode 100644 index d43ccf8048..0000000000 --- a/cli/src/assets/.vscode/settings.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "editor.inlineSuggest.enabled": true, - "[rust]": { - "editor.defaultFormatter": "rust-lang.rust-analyzer", - }, - "editor.rulers": [ - 100 - ], - "rust-analyzer.check.overrideCommand": [ - "cargo", - "clippy", - "--workspace", - "--message-format=json", - "--all-features", - "--all-targets", - "--", - "-A", - "incomplete-features" - ], - "rust-analyzer.linkedProjects": [ - "program/Cargo.toml", - "script/Cargo.toml", - ], - "rust-analyzer.showUnlinkedFileNotification": false -} \ No newline at end of file diff --git a/cli/src/assets/program/Cargo.toml b/cli/src/assets/program/Cargo.toml deleted file mode 100644 index 793baed10a..0000000000 --- a/cli/src/assets/program/Cargo.toml +++ /dev/null @@ -1,8 +0,0 @@ -[workspace] -[package] -version = "0.1.0" -name = "unnamed-program" -edition = "2021" - -[dependencies] -sp1-zkvm = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" } diff --git a/cli/src/assets/program/main.rs b/cli/src/assets/program/main.rs deleted file mode 100644 index 8bcb12df9d..0000000000 --- a/cli/src/assets/program/main.rs +++ /dev/null @@ -1,22 +0,0 @@ -//! A simple program to be proven inside the zkVM. - -#![no_main] -sp1_zkvm::entrypoint!(main); - -pub fn main() { - // NOTE: values of n larger than 186 will overflow the u128 type, - // resulting in output that doesn't match fibonacci sequence. - // However, the resulting proof will still be valid! 
- let n = sp1_zkvm::io::read::(); - let mut a: u128 = 0; - let mut b: u128 = 1; - let mut sum: u128; - for _ in 1..n { - sum = a + b; - a = b; - b = sum; - } - - sp1_zkvm::io::commit(&a); - sp1_zkvm::io::commit(&b); -} diff --git a/cli/src/assets/script/Cargo.toml b/cli/src/assets/script/Cargo.toml deleted file mode 100644 index 8efd51b9ab..0000000000 --- a/cli/src/assets/script/Cargo.toml +++ /dev/null @@ -1,11 +0,0 @@ -[workspace] -[package] -version = "0.1.0" -name = "unnamed-script" -edition = "2021" - -[dependencies] -sp1-sdk = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" } - -[build-dependencies] -sp1-helper = { git = "https://github.com/succinctlabs/sp1.git", branch = "main" } diff --git a/cli/src/assets/script/build.rs b/cli/src/assets/script/build.rs deleted file mode 100644 index 03388acab7..0000000000 --- a/cli/src/assets/script/build.rs +++ /dev/null @@ -1,3 +0,0 @@ -fn main() { - sp1_helper::build_program(&format!("{}/../program", env!("CARGO_MANIFEST_DIR"))); -} diff --git a/cli/src/assets/script/main.rs b/cli/src/assets/script/main.rs deleted file mode 100644 index 2c699b86ac..0000000000 --- a/cli/src/assets/script/main.rs +++ /dev/null @@ -1,37 +0,0 @@ -//! A simple script to generate and verify the proof of a given program. - -use sp1_sdk::{utils, ProverClient, SP1Stdin}; - -/// The ELF we want to execute inside the zkVM. -const ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf"); - -fn main() { - // Setup logging. - utils::setup_logger(); - - // Generate proof. - let mut stdin = SP1Stdin::new(); - let n = 186u32; - stdin.write(&n); - let client = ProverClient::new(); - let (pk, vk) = client.setup(ELF); - let mut proof = client.prove_compressed(&pk, stdin).expect("proving failed"); - - // Read output. - let a = proof.public_values.read::(); - let b = proof.public_values.read::(); - println!("a: {}", a); - println!("b: {}", b); - - // Verify proof. - client - .verify_compressed(&proof, &vk) - .expect("verification failed"); - - // Save proof. - proof - .save("proof-with-io.json") - .expect("saving proof failed"); - - println!("successfully generated and verified proof for the program!") -} diff --git a/cli/src/assets/script/rust-toolchain b/cli/src/assets/script/rust-toolchain deleted file mode 100644 index 3a306210c9..0000000000 --- a/cli/src/assets/script/rust-toolchain +++ /dev/null @@ -1,3 +0,0 @@ -[toolchain] -channel = "nightly-2024-04-17" -components = ["llvm-tools", "rustc-dev"] \ No newline at end of file diff --git a/cli/src/commands/install_toolchain.rs b/cli/src/commands/install_toolchain.rs index 5026c26e4f..b795e33ab1 100644 --- a/cli/src/commands/install_toolchain.rs +++ b/cli/src/commands/install_toolchain.rs @@ -11,7 +11,9 @@ use std::process::Command; #[cfg(target_family = "unix")] use std::os::unix::fs::PermissionsExt; -use crate::{get_target, get_toolchain_download_url, url_exists, RUSTUP_TOOLCHAIN_NAME}; +use crate::{ + get_target, get_toolchain_download_url, is_supported_target, url_exists, RUSTUP_TOOLCHAIN_NAME, +}; #[derive(Parser)] #[command( @@ -58,6 +60,11 @@ impl InstallToolchainCmd { Ok(_) => println!("Successfully created ~/.sp1 directory."), Err(err) => println!("Failed to create ~/.sp1 directory: {}", err), }; + + assert!( + is_supported_target(), + "Unsupported architecture. Please build the toolchain from source." 
+ ); let target = get_target(); let toolchain_asset_name = format!("rust-toolchain-{}.tar.gz", target); let toolchain_archive_path = root_dir.join(toolchain_asset_name.clone()); diff --git a/cli/src/commands/new.rs b/cli/src/commands/new.rs index d001a36bf3..ae61d55866 100644 --- a/cli/src/commands/new.rs +++ b/cli/src/commands/new.rs @@ -1,62 +1,49 @@ use anyhow::Result; use clap::Parser; -use std::{fs, path::Path}; +use std::{fs, path::Path, process::Command}; use yansi::Paint; -const PROGRAM_CARGO_TOML: &str = include_str!("../assets/program/Cargo.toml"); -const PROGRAM_MAIN_RS: &str = include_str!("../assets/program/main.rs"); -const SCRIPT_CARGO_TOML: &str = include_str!("../assets/script/Cargo.toml"); -const SCRIPT_MAIN_RS: &str = include_str!("../assets/script/main.rs"); -const SCRIPT_RUST_TOOLCHAIN: &str = include_str!("../assets/script/rust-toolchain"); -const SCRIPT_BUILD_RS: &str = include_str!("../assets/script/build.rs"); -const GIT_IGNORE: &str = include_str!("../assets/.gitignore"); -const VS_CODE_SETTINGS_JSON: &str = include_str!("../assets/.vscode/settings.json"); - #[derive(Parser)] #[command(name = "new", about = "Setup a new project that runs inside the SP1.")] pub struct NewCmd { name: String, } +const TEMPLATE_REPOSITORY_URL: &str = "https://github.com/succinctlabs/sp1-project-template"; + impl NewCmd { pub fn run(&self) -> Result<()> { let root = Path::new(&self.name); - let program_root = root.join("program"); - let script_root = root.join("script"); - - // Create the root directory. - fs::create_dir(&self.name)?; - - // Create the program directory. - fs::create_dir(&program_root)?; - fs::create_dir(program_root.join("src"))?; - fs::create_dir(program_root.join("elf"))?; - fs::write( - program_root.join("Cargo.toml"), - PROGRAM_CARGO_TOML.replace("unnamed", &self.name), - )?; - fs::write(program_root.join("src").join("main.rs"), PROGRAM_MAIN_RS)?; - - // Create the runner directory. - fs::create_dir(&script_root)?; - fs::create_dir(script_root.join("src"))?; - fs::write( - script_root.join("Cargo.toml"), - SCRIPT_CARGO_TOML.replace("unnamed", &self.name), - )?; - fs::write(script_root.join("src").join("main.rs"), SCRIPT_MAIN_RS)?; - fs::write(script_root.join("rust-toolchain"), SCRIPT_RUST_TOOLCHAIN)?; - fs::write(script_root.join("build.rs"), SCRIPT_BUILD_RS)?; - - // Add .gitignore file to root. - fs::write(root.join(".gitignore"), GIT_IGNORE)?; - // Add .vscode/settings.json to root. - fs::create_dir(root.join(".vscode"))?; - fs::write( - root.join(".vscode").join("settings.json"), - VS_CODE_SETTINGS_JSON, - )?; + // Create the root directory if it doesn't exist. + if !root.exists() { + fs::create_dir(&self.name)?; + } + + // Clone the repository. + let output = Command::new("git") + .arg("clone") + .arg(TEMPLATE_REPOSITORY_URL) + .arg(root.as_os_str()) + .arg("--recurse-submodules") + .arg("--depth=1") + .output() + .expect("failed to execute command"); + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + return Err(anyhow::anyhow!("failed to clone repository: {}", stderr)); + } + + // Remove the .git directory. + fs::remove_dir_all(root.join(".git"))?; + + // Check if the user has `foundry` installed. 
+ if Command::new("foundry").arg("--version").output().is_err() { + println!( + " \x1b[1m{}\x1b[0m Make sure to install Foundry to use contracts: https://book.getfoundry.sh/getting-started/installation.", + Paint::yellow("Warning:"), + ); + } println!( " \x1b[1m{}\x1b[0m {} ({})", diff --git a/cli/src/lib.rs b/cli/src/lib.rs index 87fd0c0c72..1ad6fbacee 100644 --- a/cli/src/lib.rs +++ b/cli/src/lib.rs @@ -37,6 +37,23 @@ pub async fn url_exists(client: &Client, url: &str) -> bool { res.is_ok() } +#[allow(unreachable_code)] +pub fn is_supported_target() -> bool { + #[cfg(all(target_arch = "x86_64", target_os = "linux"))] + return true; + + #[cfg(all(target_arch = "aarch64", target_os = "linux"))] + return true; + + #[cfg(all(target_arch = "x86_64", target_os = "macos"))] + return true; + + #[cfg(all(target_arch = "aarch64", target_os = "macos"))] + return true; + + false +} + pub fn get_target() -> String { target_lexicon::HOST.to_string() } diff --git a/core/Cargo.toml b/core/Cargo.toml index df02550f1a..85ad856103 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -7,7 +7,7 @@ version = "0.1.0" bincode = "1.3.3" serde = { version = "1.0", features = ["derive", "rc"] } elf = "0.7.4" -itertools = "0.12.1" +itertools = "0.13.0" log = "0.4.21" nohash-hasher = "0.2.0" num = { version = "0.4.3" } @@ -58,6 +58,8 @@ web-time = "1.1.0" rayon-scan = "0.1.1" thiserror = "1.0.60" num-bigint = { version = "0.4.3", default-features = false } +rand = "0.8.5" +bytemuck = "1.16.0" [dev-dependencies] tiny-keccak = { version = "2.0.2", features = ["keccak"] } @@ -67,8 +69,9 @@ rand = "0.8.5" sp1-zkvm = { path = "../zkvm/entrypoint" } [features] -debug = [] neon = ["p3-blake3/neon"] +programs = [] +debug = [] [[bench]] harness = false diff --git a/core/src/air/builder.rs b/core/src/air/builder.rs index eba567d920..d957f43cd9 100644 --- a/core/src/air/builder.rs +++ b/core/src/air/builder.rs @@ -308,6 +308,7 @@ pub trait AluAirBuilder: BaseAirBuilder { c: Word>, shard: impl Into, channel: impl Into, + nonce: impl Into, multiplicity: impl Into, ) { let values = once(opcode.into()) @@ -316,6 +317,7 @@ pub trait AluAirBuilder: BaseAirBuilder { .chain(c.0.into_iter().map(Into::into)) .chain(once(shard.into())) .chain(once(channel.into())) + .chain(once(nonce.into())) .collect(); self.send(AirInteraction::new( @@ -335,6 +337,7 @@ pub trait AluAirBuilder: BaseAirBuilder { c: Word>, shard: impl Into, channel: impl Into, + nonce: impl Into, multiplicity: impl Into, ) { let values = once(opcode.into()) @@ -343,6 +346,7 @@ pub trait AluAirBuilder: BaseAirBuilder { .chain(c.0.into_iter().map(Into::into)) .chain(once(shard.into())) .chain(once(channel.into())) + .chain(once(nonce.into())) .collect(); self.receive(AirInteraction::new( @@ -359,6 +363,7 @@ pub trait AluAirBuilder: BaseAirBuilder { shard: impl Into + Clone, channel: impl Into + Clone, clk: impl Into + Clone, + nonce: impl Into + Clone, syscall_id: impl Into + Clone, arg1: impl Into + Clone, arg2: impl Into + Clone, @@ -369,6 +374,7 @@ pub trait AluAirBuilder: BaseAirBuilder { shard.clone().into(), channel.clone().into(), clk.clone().into(), + nonce.clone().into(), syscall_id.clone().into(), arg1.clone().into(), arg2.clone().into(), @@ -385,6 +391,7 @@ pub trait AluAirBuilder: BaseAirBuilder { shard: impl Into + Clone, channel: impl Into + Clone, clk: impl Into + Clone, + nonce: impl Into + Clone, syscall_id: impl Into + Clone, arg1: impl Into + Clone, arg2: impl Into + Clone, @@ -395,6 +402,7 @@ pub trait AluAirBuilder: BaseAirBuilder { 
shard.clone().into(), channel.clone().into(), clk.clone().into(), + nonce.clone().into(), syscall_id.clone().into(), arg1.clone().into(), arg2.clone().into(), diff --git a/core/src/alu/add_sub/mod.rs b/core/src/alu/add_sub/mod.rs index 2321427c53..3179d4d775 100644 --- a/core/src/alu/add_sub/mod.rs +++ b/core/src/alu/add_sub/mod.rs @@ -1,7 +1,8 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use p3_air::{Air, BaseAir}; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::AbstractField; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; @@ -38,6 +39,9 @@ pub struct AddSubCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// Instance of `AddOperation` to handle addition logic in `AddSubChip`'s ALU operations. /// It's result will be `a` for the add operation and `b` for the sub operation. pub add_operation: AddOperation, @@ -129,6 +133,13 @@ impl MachineAir for AddSubChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut AddSubCols = + trace.values[i * NUM_ADD_SUB_COLS..(i + 1) * NUM_ADD_SUB_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -151,6 +162,14 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &AddSubCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &AddSubCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); // Evaluate the addition operation. AddOperation::::eval( @@ -172,6 +191,7 @@ where local.operand_2, local.shard, local.channel, + local.nonce, local.is_add, ); @@ -183,6 +203,7 @@ where local.operand_2, local.shard, local.channel, + local.nonce, local.is_sub, ); diff --git a/core/src/alu/bitwise/mod.rs b/core/src/alu/bitwise/mod.rs index 3e7227b709..81163b11e1 100644 --- a/core/src/alu/bitwise/mod.rs +++ b/core/src/alu/bitwise/mod.rs @@ -1,7 +1,9 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; +use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; +use p3_field::AbstractField; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; @@ -31,6 +33,9 @@ pub struct BitwiseCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. pub a: Word, @@ -111,6 +116,12 @@ impl MachineAir for BitwiseChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); + for i in 0..trace.height() { + let cols: &mut BitwiseCols = + trace.values[i * NUM_BITWISE_COLS..(i + 1) * NUM_BITWISE_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -133,6 +144,14 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &BitwiseCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &BitwiseCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); // Get the opcode for the operation. 
let opcode = local.is_xor * ByteOpcode::XOR.as_field::() @@ -166,6 +185,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_xor + local.is_or + local.is_and, ); diff --git a/core/src/alu/divrem/mod.rs b/core/src/alu/divrem/mod.rs index b54206469f..84198b16bb 100644 --- a/core/src/alu/divrem/mod.rs +++ b/core/src/alu/divrem/mod.rs @@ -64,6 +64,7 @@ mod utils; use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; +use std::collections::HashMap; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::AbstractField; @@ -72,11 +73,10 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sp1_derive::AlignedBorrow; -use self::utils::eval_abs_value; use crate::air::MachineAir; use crate::air::{SP1AirBuilder, Word}; use crate::alu::divrem::utils::{get_msb, get_quotient_and_remainder, is_signed_operation}; -use crate::alu::AluEvent; +use crate::alu::{create_alu_lookups, AluEvent}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::disassembler::WORD_SIZE; @@ -107,6 +107,9 @@ pub struct DivRemCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. pub a: Word, @@ -184,6 +187,23 @@ pub struct DivRemCols { /// Flag to indicate whether `c` is negative. pub c_neg: T, + /// The lower nonce of the operation. + pub lower_nonce: T, + + /// The upper nonce of the operation. + pub upper_nonce: T, + + /// The absolute nonce of the operation. + pub abs_nonce: T, + + /// Selector to determine whether an ALU Event is sent for absolute value computation of `c`. + pub abs_c_alu_event: T, + pub abs_c_alu_event_nonce: T, + + /// Selector to determine whether an ALU Event is sent for absolute value computation of `rem`. + pub abs_rem_alu_event: T, + pub abs_rem_alu_event_nonce: T, + /// Selector to know whether this row is enabled. pub is_real: T, @@ -259,6 +279,24 @@ impl MachineAir for DivRemChip { cols.max_abs_c_or_1 = Word::from(u32::max(1, event.c)); } + // Set the `alu_event` flags. + cols.abs_c_alu_event = cols.c_neg * cols.is_real; + cols.abs_c_alu_event_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[4]) + .copied() + .unwrap_or_default(), + ); + cols.abs_rem_alu_event = cols.rem_neg * cols.is_real; + cols.abs_rem_alu_event_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[5]) + .copied() + .unwrap_or_default(), + ); + // Insert the MSB lookup events. { let words = [event.b, event.c, remainder]; @@ -281,7 +319,7 @@ impl MachineAir for DivRemChip { // Calculate the modified multiplicity { - cols.remainder_check_multiplicity = cols.is_real * cols.is_c_0.result; + cols.remainder_check_multiplicity = cols.is_real * (F::one() - cols.is_c_0.result); } // Calculate c * quotient + remainder. @@ -321,6 +359,40 @@ impl MachineAir for DivRemChip { // mul and LT upon which div depends. This ordering is critical as mul and LT // require all the mul and LT events be added before we can call generate_trace. { + // Insert the absolute value computation events. 
+ { + let mut add_events: Vec = vec![]; + if cols.abs_c_alu_event == F::one() { + add_events.push(AluEvent { + lookup_id: event.sub_lookups[4], + shard: event.shard, + channel: event.channel, + clk: event.clk, + opcode: Opcode::ADD, + a: 0, + b: event.c, + c: (event.c as i32).abs() as u32, + sub_lookups: create_alu_lookups(), + }) + } + if cols.abs_rem_alu_event == F::one() { + add_events.push(AluEvent { + lookup_id: event.sub_lookups[5], + shard: event.shard, + channel: event.channel, + clk: event.clk, + opcode: Opcode::ADD, + a: 0, + b: remainder, + c: (remainder as i32).abs() as u32, + sub_lookups: create_alu_lookups(), + }) + } + let mut alu_events = HashMap::new(); + alu_events.insert(Opcode::ADD, add_events); + output.add_alu_events(alu_events); + } + let mut lower_word = 0; for i in 0..WORD_SIZE { lower_word += (c_times_quotient[i] as u32) << (i * BYTE_SIZE); @@ -332,6 +404,7 @@ impl MachineAir for DivRemChip { } let lower_multiplication = AluEvent { + lookup_id: event.sub_lookups[0], shard: event.shard, channel: event.channel, clk: event.clk, @@ -339,10 +412,19 @@ impl MachineAir for DivRemChip { a: lower_word, c: event.c, b: quotient, + sub_lookups: create_alu_lookups(), }; + cols.lower_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[0]) + .copied() + .unwrap_or_default(), + ); output.add_mul_event(lower_multiplication); let upper_multiplication = AluEvent { + lookup_id: event.sub_lookups[1], shard: event.shard, channel: event.channel, clk: event.clk, @@ -356,22 +438,45 @@ impl MachineAir for DivRemChip { a: upper_word, c: event.c, b: quotient, + sub_lookups: create_alu_lookups(), }; - + cols.upper_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[1]) + .copied() + .unwrap_or_default(), + ); output.add_mul_event(upper_multiplication); - let lt_event = if is_signed_operation(event.opcode) { + cols.abs_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[2]) + .copied() + .unwrap_or_default(), + ); AluEvent { + lookup_id: event.sub_lookups[2], shard: event.shard, channel: event.channel, - opcode: Opcode::SLT, + opcode: Opcode::SLTU, a: 1, b: (remainder as i32).abs() as u32, c: u32::max(1, (event.c as i32).abs() as u32), clk: event.clk, + sub_lookups: create_alu_lookups(), } } else { + cols.abs_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[3]) + .copied() + .unwrap_or_default(), + ); AluEvent { + lookup_id: event.sub_lookups[3], shard: event.shard, channel: event.channel, opcode: Opcode::SLTU, @@ -379,8 +484,10 @@ impl MachineAir for DivRemChip { b: remainder, c: u32::max(1, event.c), clk: event.clk, + sub_lookups: create_alu_lookups(), } }; + if cols.remainder_check_multiplicity == F::one() { output.add_lt_event(lt_event); } @@ -430,6 +537,13 @@ impl MachineAir for DivRemChip { trace.values[i] = padded_row_template[i % NUM_DIVREM_COLS]; } + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut DivRemCols = + trace.values[i * NUM_DIVREM_COLS..(i + 1) * NUM_DIVREM_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -452,10 +566,18 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &DivRemCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &DivRemCols = (*next).borrow(); let base = AB::F::from_canonical_u32(1 << 8); let one: AB::Expr = AB::F::one().into(); let zero: AB::Expr = AB::F::zero().into(); + // Constrain the incrementing nonce. 
+ builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + // Calculate whether b, remainder, and c are negative. { // Negative if and only if op code is signed & MSB = 1. @@ -490,6 +612,7 @@ where local.c, local.shard, local.channel, + local.lower_nonce, local.is_real, ); @@ -515,6 +638,7 @@ where local.c, local.shard, local.channel, + local.upper_nonce, local.is_real, ); } @@ -659,18 +783,37 @@ where // Range check remainder. (i.e., |remainder| < |c| when not is_c_0) { - eval_abs_value( - builder, - local.remainder.borrow(), - local.abs_remainder.borrow(), - local.rem_neg.borrow(), + // For each of `c` and `rem`, assert that the absolute value is equal to the original value, + // if the original value is non-negative or the minimum i32. + for i in 0..WORD_SIZE { + builder + .when_not(local.c_neg) + .assert_eq(local.c[i], local.abs_c[i]); + builder + .when_not(local.rem_neg) + .assert_eq(local.remainder[i], local.abs_remainder[i]); + } + // In the case that `c` or `rem` is negative, instead check that their sum is zero by + // sending an AddEvent. + builder.send_alu( + AB::Expr::from_canonical_u32(Opcode::ADD as u32), + Word([zero.clone(), zero.clone(), zero.clone(), zero.clone()]), + local.c, + local.abs_c, + local.shard, + local.channel, + local.abs_c_alu_event_nonce, + local.abs_c_alu_event, ); - - eval_abs_value( - builder, - local.c.borrow(), - local.abs_c.borrow(), - local.c_neg.borrow(), + builder.send_alu( + AB::Expr::from_canonical_u32(Opcode::ADD as u32), + Word([zero.clone(), zero.clone(), zero.clone(), zero.clone()]), + local.remainder, + local.abs_remainder, + local.shard, + local.channel, + local.abs_rem_alu_event_nonce, + local.abs_rem_alu_event, ); // max(abs(c), 1) = abs(c) * (1 - is_c_0) + 1 * is_c_0 @@ -691,29 +834,31 @@ where builder.assert_eq(local.max_abs_c_or_1[i], max_abs_c_or_1[i].clone()); } - let opcode = { - let is_signed = local.is_div + local.is_rem; - let is_unsigned = local.is_divu + local.is_remu; - let slt = AB::Expr::from_canonical_u32(Opcode::SLT as u32); - let sltu = AB::Expr::from_canonical_u32(Opcode::SLTU as u32); - is_signed * slt + is_unsigned * sltu - }; - - // Check that the event multiplicity column is computed correctly. + // Handle cases: + // - If is_real == 0 then remainder_check_multiplicity == 0 is forced. + // - If is_real == 1 then is_c_0_result must be the expected one, so + // remainder_check_multiplicity = (1 - is_c_0_result) * is_real. builder.assert_eq( + (AB::Expr::one() - local.is_c_0.result) * local.is_real, local.remainder_check_multiplicity, - local.is_c_0.result * local.is_real, ); + // the cleaner idea is simply remainder_check_multiplicity == (1 - is_c_0_result) * is_real + + // Check that the absolute value selector columns are computed correctly. + builder.assert_eq(local.abs_c_alu_event, local.c_neg * local.is_real); + builder.assert_eq(local.abs_rem_alu_event, local.rem_neg * local.is_real); + // Dispatch abs(remainder) < max(abs(c), 1), this is equivalent to abs(remainder) < // abs(c) if not division by 0. 
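The multiplicity that gates this dispatch was flipped in the trace generator earlier in the diff, from `is_real * is_c_0.result` to `is_real * (1 - is_c_0.result)`, and the constraint above mirrors that: the `|remainder| < max(|c|, 1)` lookup should fire on every real row except a division by zero, where RISC-V fixes the result and the strict comparison would not hold in general. A small truth-table sketch of the gating, with the columns treated as plain booleans:

```rust
/// Expected multiplicity of the `|rem| < max(|c|, 1)` lookup on one row.
fn remainder_check_multiplicity(is_real: bool, is_c_0: bool) -> u32 {
    // Matches the constraint: multiplicity == (1 - is_c_0) * is_real.
    (is_real && !is_c_0) as u32
}

fn main() {
    assert_eq!(remainder_check_multiplicity(false, false), 0); // padding row
    assert_eq!(remainder_check_multiplicity(false, true), 0);  // padding row
    assert_eq!(remainder_check_multiplicity(true, true), 0);   // real row, c == 0: skip the check
    assert_eq!(remainder_check_multiplicity(true, false), 1);  // real row, c != 0: dispatch it
}
```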
builder.send_alu( - opcode, + AB::Expr::from_canonical_u32(Opcode::SLTU as u32), Word([one.clone(), zero.clone(), zero.clone(), zero.clone()]), local.abs_remainder, local.max_abs_c_or_1, local.shard, local.channel, + local.abs_nonce, local.remainder_check_multiplicity, ); } @@ -783,6 +928,8 @@ where local.rem_neg, local.c_neg, local.is_real, + local.abs_c_alu_event, + local.abs_rem_alu_event, ]; for flag in bool_flags.iter() { @@ -817,6 +964,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_real, ); } diff --git a/core/src/alu/divrem/utils.rs b/core/src/alu/divrem/utils.rs index f3a7b7070f..d71c35aad6 100644 --- a/core/src/alu/divrem/utils.rs +++ b/core/src/alu/divrem/utils.rs @@ -1,7 +1,3 @@ -use p3_air::AirBuilder; -use p3_field::AbstractField; - -use crate::air::{SP1AirBuilder, Word, WORD_SIZE}; use crate::runtime::Opcode; /// Returns `true` if the given `opcode` is a signed operation. @@ -32,47 +28,3 @@ pub fn get_quotient_and_remainder(b: u32, c: u32, opcode: Opcode) -> (u32, u32) pub const fn get_msb(a: u32) -> u8 { ((a >> 31) & 1) as u8 } - -/// Verifies that `abs_value = abs(value)` using `is_negative` as a flag. -/// -/// `abs(value) + value = 0` if `value` is negative. `abs(value) = value` otherwise. -/// -/// In two's complement arithmetic, the negation involves flipping its bits and adding 1. Therefore, -/// for a negative number, `abs(value) + value` equals 0. This is because `abs(value)` is the two's -/// complement (negation) of `value`. For a positive number, `abs(value)` is the same as `value`. -/// -/// The function iterates over each limb of the `value` and `abs_value`, checking the following -/// conditions: -/// -/// 1. If `value` is non-negative, it checks that each limb in `value` and `abs_value` is identical. -/// 2. If `value` is negative, it checks that the sum of each corresponding limb in `value` and -/// `abs_value` equals the expected sum for a two's complement representation. The least -/// significant limb (first limb) should add up to `0xff + 1` (to account for the +1 in two's -/// complement negation), and other limbs should add up to `0xff` (as the rest of the limbs just -/// have their bits flipped). -pub fn eval_abs_value( - builder: &mut AB, - value: &Word, - abs_value: &Word, - is_negative: &AB::Var, -) where - AB: SP1AirBuilder, -{ - for i in 0..WORD_SIZE { - let exp_sum_if_negative = AB::Expr::from_canonical_u32({ - if i == 0 { - 0xff + 1 - } else { - 0xff - } - }); - - builder - .when(*is_negative) - .assert_eq(value[i] + abs_value[i], exp_sum_if_negative.clone()); - - builder - .when_not(*is_negative) - .assert_eq(value[i], abs_value[i]); - } -} diff --git a/core/src/alu/lt/mod.rs b/core/src/alu/lt/mod.rs index 91b504181c..54d5768c2c 100644 --- a/core/src/alu/lt/mod.rs +++ b/core/src/alu/lt/mod.rs @@ -34,6 +34,9 @@ pub struct LtCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// If the opcode is SLT. pub is_slt: T, @@ -220,6 +223,13 @@ impl MachineAir for LtChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); + // Write the nonces to the trace. 
+ for i in 0..trace.height() { + let cols: &mut LtCols = + trace.values[i * NUM_LT_COLS..(i + 1) * NUM_LT_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -242,6 +252,14 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &LtCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &LtCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); let is_real = local.is_slt + local.is_sltu; @@ -431,6 +449,7 @@ where local.c, local.shard, local.channel, + local.nonce, is_real, ); } diff --git a/core/src/alu/mod.rs b/core/src/alu/mod.rs index c667c612c8..a67d1ff909 100644 --- a/core/src/alu/mod.rs +++ b/core/src/alu/mod.rs @@ -11,6 +11,7 @@ pub use bitwise::*; pub use divrem::*; pub use lt::*; pub use mul::*; +use rand::Rng; pub use sll::*; pub use sr::*; @@ -21,6 +22,9 @@ use crate::runtime::Opcode; /// A standard format for describing ALU operations that need to be proven. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct AluEvent { + /// The lookup id of the event. + pub lookup_id: usize, + /// The shard number, used for byte lookup table. pub shard: u32, @@ -41,12 +45,15 @@ pub struct AluEvent { // The second input operand. pub c: u32, + + pub sub_lookups: [usize; 6], } impl AluEvent { /// Creates a new `AluEvent`. pub fn new(shard: u32, channel: u32, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) -> Self { Self { + lookup_id: 0, shard, channel, clk, @@ -54,6 +61,24 @@ impl AluEvent { a, b, c, + sub_lookups: create_alu_lookups(), } } } + +pub fn create_alu_lookup_id() -> usize { + let mut rng = rand::thread_rng(); + rng.gen() +} + +pub fn create_alu_lookups() -> [usize; 6] { + let mut rng = rand::thread_rng(); + [ + rng.gen(), + rng.gen(), + rng.gen(), + rng.gen(), + rng.gen(), + rng.gen(), + ] +} diff --git a/core/src/alu/mul/mod.rs b/core/src/alu/mul/mod.rs index c30a59c4f4..1351e78c38 100644 --- a/core/src/alu/mul/mod.rs +++ b/core/src/alu/mul/mod.rs @@ -79,6 +79,9 @@ pub struct MulCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. pub a: Word, @@ -270,6 +273,13 @@ impl MachineAir for MulChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut MulCols = + trace.values[i * NUM_MUL_COLS..(i + 1) * NUM_MUL_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -292,12 +302,20 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &MulCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &MulCols = (*next).borrow(); let base = AB::F::from_canonical_u32(1 << 8); let zero: AB::Expr = AB::F::zero().into(); let one: AB::Expr = AB::F::one().into(); let byte_mask = AB::F::from_canonical_u8(BYTE_MASK); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + // Calculate the MSBs. let (b_msb, c_msb) = { let msb_pairs = [ @@ -412,14 +430,6 @@ where .when(local.c_sign_extend) .assert_eq(local.c_msb, one.clone()); - // If the opcode doesn't allow sign extension for an operand, we must not extend their sign. 
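The freshly generated `lookup_id` / `sub_lookups` values are plain random `usize` handles, not field elements: an event records them when it sends an ALU interaction, the chip that receives the interaction is expected to publish `id -> its own row index (the nonce)` in the execution record's `nonce_lookup` map, and the sender copies that nonce into its `*_nonce` column so both sides of the lookup carry the same tuple. A minimal sketch of that correlation; how the record actually builds the map is an assumption here, and only the `get(..).copied().unwrap_or_default()` read mirrors the trace generators above:

```rust
use std::collections::HashMap;

fn main() {
    // Hypothetical ALU events, each tagged with a random lookup id by whoever created it.
    let events: Vec<(usize /* lookup_id */, u32 /* b */, u32 /* c */)> =
        vec![(0xDEAD_BEEF, 7, 8), (0xCAFE_F00D, 1, 2)];

    // Assumed wiring: the receiving chip assigns nonces in row order and publishes
    // lookup_id -> nonce so senders can look their nonce up later.
    let mut nonce_lookup: HashMap<usize, u32> = HashMap::new();
    for (row, (lookup_id, _b, _c)) in events.iter().enumerate() {
        nonce_lookup.insert(*lookup_id, row as u32);
    }

    // The sending side fetches the nonce exactly like the code above, defaulting
    // to 0 for padding rows or ids that never got a matching receive.
    assert_eq!(nonce_lookup.get(&0xCAFE_F00D).copied().unwrap_or_default(), 1);
    assert_eq!(nonce_lookup.get(&0x0).copied().unwrap_or_default(), 0);
}
```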
- builder - .when(local.is_mul + local.is_mulhu) - .assert_zero(local.b_sign_extend + local.c_sign_extend); - builder - .when(local.is_mul + local.is_mulhsu + local.is_mulhsu) - .assert_zero(local.c_sign_extend); - // Calculate the opcode. let opcode = { // Exactly one of the op codes must be on. @@ -455,6 +465,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_real, ); } diff --git a/core/src/alu/sll/mod.rs b/core/src/alu/sll/mod.rs index d87ee780d6..b5a711542b 100644 --- a/core/src/alu/sll/mod.rs +++ b/core/src/alu/sll/mod.rs @@ -67,6 +67,9 @@ pub struct ShiftLeftCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. pub a: Word, @@ -199,6 +202,12 @@ impl MachineAir for ShiftLeft { trace.values[i] = padded_row_template[i % NUM_SHIFT_LEFT_COLS]; } + for i in 0..trace.height() { + let cols: &mut ShiftLeftCols = + trace.values[i * NUM_SHIFT_LEFT_COLS..(i + 1) * NUM_SHIFT_LEFT_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -221,11 +230,19 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &ShiftLeftCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &ShiftLeftCols = (*next).borrow(); let zero: AB::Expr = AB::F::zero().into(); let one: AB::Expr = AB::F::one().into(); let base: AB::Expr = AB::F::from_canonical_u32(1 << BYTE_SIZE).into(); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + // We first "bit shift" and next we "byte shift". Then we compare the results with a. // Finally, we perform some misc checks. @@ -354,6 +371,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_real, ); } diff --git a/core/src/alu/sr/mod.rs b/core/src/alu/sr/mod.rs index 8f9ea721e5..bd7a91d52e 100644 --- a/core/src/alu/sr/mod.rs +++ b/core/src/alu/sr/mod.rs @@ -85,6 +85,9 @@ pub struct ShiftRightCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. pub a: Word, @@ -283,6 +286,13 @@ impl MachineAir for ShiftRightChip { trace.values[i] = padded_row_template[i % NUM_SHIFT_RIGHT_COLS]; } + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut ShiftRightCols = + trace.values[i * NUM_SHIFT_RIGHT_COLS..(i + 1) * NUM_SHIFT_RIGHT_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -305,9 +315,17 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &ShiftRightCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &ShiftRightCols = (*next).borrow(); let zero: AB::Expr = AB::F::zero().into(); let one: AB::Expr = AB::F::one().into(); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + // Check that the MSB of most_significant_byte matches local.b_msb using lookup. { let byte = local.b[WORD_SIZE - 1]; @@ -464,6 +482,9 @@ where for shift_by_n_bit in local.shift_by_n_bits.iter() { builder.assert_bool(*shift_by_n_bit); } + for bit in local.c_least_sig_byte.iter() { + builder.assert_bool(*bit); + } } // Range check bytes. 
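The `c_least_sig_byte` bits that are now constrained to be boolean feed the shift-amount logic: only the low five bits of `c` matter, and the shift chips split `c % 32` into a whole-byte shift plus a residual bit shift (the SLL comment earlier in the diff describes the same two-step idea, with the steps in the opposite order; the two orders commute). A standalone sanity check of that decomposition against the native shift, assuming the byte-then-bit structure:

```rust
/// Logical right shift done as a byte shift followed by a bit shift.
fn srl_two_step(b: u32, c: u32) -> u32 {
    let amount = (c & 0x1F) as usize; // only the low 5 bits of c's least significant byte
    let (byte_shift, bit_shift) = (amount / 8, amount % 8);

    // Step 1: shift whole bytes.
    let src = b.to_le_bytes();
    let mut shifted = [0u8; 4];
    for i in 0..4 - byte_shift {
        shifted[i] = src[i + byte_shift];
    }

    // Step 2: shift the remaining 0..=7 bits.
    u32::from_le_bytes(shifted) >> bit_shift
}

fn main() {
    for &(b, c) in &[(0xDEAD_BEEFu32, 0), (0xDEAD_BEEF, 7), (0xDEAD_BEEF, 8), (0xDEAD_BEEF, 31), (1, 33)] {
        assert_eq!(srl_two_step(b, c), b >> (c & 0x1F));
    }
}
```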
@@ -485,6 +506,9 @@ where builder.assert_bool(local.is_sra); builder.assert_bool(local.is_real); + // Check that is_real is the sum of the two operation flags. + builder.assert_eq(local.is_srl + local.is_sra, local.is_real); + // Receive the arguments. builder.receive_alu( local.is_srl * AB::F::from_canonical_u32(Opcode::SRL as u32) @@ -494,6 +518,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_real, ); } diff --git a/core/src/bytes/mod.rs b/core/src/bytes/mod.rs index f6d5bc482c..2cf8c8fb1c 100644 --- a/core/src/bytes/mod.rs +++ b/core/src/bytes/mod.rs @@ -24,7 +24,7 @@ use crate::bytes::trace::NUM_ROWS; pub const NUM_BYTE_OPS: usize = 9; /// The number of different byte lookup channels. -pub const NUM_BYTE_LOOKUP_CHANNELS: u32 = 4; +pub const NUM_BYTE_LOOKUP_CHANNELS: u32 = 16; /// A chip for computing byte operations. /// diff --git a/core/src/bytes/trace.rs b/core/src/bytes/trace.rs index 39f2b72ff5..d6ba3921b6 100644 --- a/core/src/bytes/trace.rs +++ b/core/src/bytes/trace.rs @@ -28,7 +28,7 @@ impl MachineAir for ByteChip { } fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option> { - // TODO: We should be able to make this a constant. Also, trace / map should be separate. + // OPT: We should be able to make this a constant. Also, trace / map should be separate. // Since we only need the trace and not the map, we can just pass 0 as the shard. let (trace, _) = Self::trace_and_map(0); diff --git a/core/src/cpu/air/branch.rs b/core/src/cpu/air/branch.rs index fad654de35..60bba4175e 100644 --- a/core/src/cpu/air/branch.rs +++ b/core/src/cpu/air/branch.rs @@ -3,6 +3,7 @@ use p3_field::AbstractField; use crate::air::{BaseAirBuilder, SP1AirBuilder, Word, WordAirBuilder}; use crate::cpu::columns::{CpuCols, OpcodeSelectorCols}; +use crate::operations::BabyBearWordRangeChecker; use crate::{cpu::CpuChip, runtime::Opcode}; impl CpuChip { @@ -57,6 +58,20 @@ impl CpuChip { .when(local.branching) .assert_eq(branch_cols.next_pc.reduce::(), local.next_pc); + // Range check branch_cols.pc and branch_cols.next_pc. + BabyBearWordRangeChecker::::range_check( + builder, + branch_cols.pc, + branch_cols.pc_range_checker, + is_branch_instruction.clone(), + ); + BabyBearWordRangeChecker::::range_check( + builder, + branch_cols.next_pc, + branch_cols.next_pc_range_checker, + is_branch_instruction.clone(), + ); + // When we are branching, calculate branch_cols.next_pc <==> branch_cols.pc + c. builder.send_alu( Opcode::ADD.as_field::(), @@ -65,6 +80,7 @@ impl CpuChip { local.op_c_val(), local.shard, local.channel, + branch_cols.next_pc_nonce, local.branching, ); @@ -83,15 +99,21 @@ impl CpuChip { .when(local.is_real) .when(local.not_branching) .assert_eq(local.pc + AB::Expr::from_canonical_u8(4), local.next_pc); - } - // Evaluate branching value constraints. - { - // Assert that local.is_branching is a bit. + // Assert that either we are branching or not branching when the instruction is a branch. + builder + .when(is_branch_instruction.clone()) + .assert_one(local.branching + local.not_branching); builder .when(is_branch_instruction.clone()) .assert_bool(local.branching); + builder + .when(is_branch_instruction.clone()) + .assert_bool(local.not_branching); + } + // Evaluate branching value constraints. + { // When the opcode is BEQ and we are branching, assert that a_eq_b is true. 
builder .when(local.selectors.is_beq * local.branching) @@ -146,6 +168,11 @@ impl CpuChip { .when(is_branch_instruction.clone() * branch_cols.a_eq_b) .assert_word_eq(local.op_a_val(), local.op_b_val()); + // To prevent this ALU send to be arbitrarily large when is_branch_instruction is false. + builder + .when_not(is_branch_instruction.clone()) + .assert_zero(local.branching); + // Calculate a_lt_b <==> a < b (using appropriate signedness). let use_signed_comparison = local.selectors.is_blt + local.selectors.is_bge; builder.send_alu( @@ -157,6 +184,7 @@ impl CpuChip { local.op_b_val(), local.shard, local.channel, + branch_cols.a_lt_b_nonce, is_branch_instruction.clone(), ); @@ -169,6 +197,7 @@ impl CpuChip { local.op_a_val(), local.shard, local.channel, + branch_cols.a_gt_b_nonce, is_branch_instruction.clone(), ); } diff --git a/core/src/cpu/air/ecall.rs b/core/src/cpu/air/ecall.rs index 506b2c7b75..870513b83e 100644 --- a/core/src/cpu/air/ecall.rs +++ b/core/src/cpu/air/ecall.rs @@ -35,14 +35,19 @@ impl CpuChip { let syscall_id = syscall_code[0]; let send_to_table = syscall_code[1]; - // When is_ecall_instruction == true AND sent_to_table == true, ecall_mul_send_to_table should be true. - builder - .when(is_ecall_instruction.clone()) - .assert_eq(send_to_table, local.ecall_mul_send_to_table); + // Handle cases: + // - is_ecall_instruction = 1 => ecall_mul_send_to_table == send_to_table + // - is_ecall_instruction = 0 => ecall_mul_send_to_table == 0 + builder.assert_eq( + local.ecall_mul_send_to_table, + send_to_table * is_ecall_instruction.clone(), + ); + builder.send_syscall( local.shard, local.channel, local.clk, + ecall_cols.syscall_nonce, syscall_id, local.op_b_val().reduce::(), local.op_c_val().reduce::(), diff --git a/core/src/cpu/air/memory.rs b/core/src/cpu/air/memory.rs index 6ac1a07c11..707a50ff95 100644 --- a/core/src/cpu/air/memory.rs +++ b/core/src/cpu/air/memory.rs @@ -5,6 +5,7 @@ use crate::air::{BaseAirBuilder, SP1AirBuilder, Word, WordAirBuilder}; use crate::cpu::columns::{CpuCols, MemoryColumns, OpcodeSelectorCols}; use crate::cpu::CpuChip; use crate::memory::MemoryCols; +use crate::operations::BabyBearWordRangeChecker; use crate::runtime::{MemoryAccessPosition, Opcode}; impl CpuChip { @@ -66,6 +67,15 @@ impl CpuChip { local.op_c_val(), local.shard, local.channel, + memory_columns.addr_word_nonce, + is_memory_instruction.clone(), + ); + + // Range check the addr_word to be a valid babybear word. + BabyBearWordRangeChecker::::range_check( + builder, + memory_columns.addr_word, + memory_columns.addr_word_range_checker, is_memory_instruction.clone(), ); @@ -88,6 +98,35 @@ impl CpuChip { memory_columns.addr_word.reduce::(), ); + // Verify that the least significant byte of addr_word - addr_offset is divisible by 4. 
+ let offset = [ + memory_columns.offset_is_one, + memory_columns.offset_is_two, + memory_columns.offset_is_three, + ] + .iter() + .enumerate() + .fold(AB::Expr::zero(), |acc, (index, &value)| { + acc + AB::Expr::from_canonical_usize(index + 1) * value + }); + let mut recomposed_byte = AB::Expr::zero(); + memory_columns + .aa_least_sig_byte_decomp + .iter() + .enumerate() + .for_each(|(i, value)| { + builder + .when(is_memory_instruction.clone()) + .assert_bool(*value); + + recomposed_byte = + recomposed_byte.clone() + AB::Expr::from_canonical_usize(1 << (i + 2)) * *value; + }); + + builder + .when(is_memory_instruction.clone()) + .assert_eq(memory_columns.addr_word[0] - offset, recomposed_byte); + // For operations that require reading from memory (not registers), we need to read the // value into the memory columns. builder.eval_memory_access( @@ -98,6 +137,14 @@ impl CpuChip { &memory_columns.memory_access, is_memory_instruction.clone(), ); + + // On memory load instructions, make sure that the memory value is not changed. + builder + .when(self.is_load_instruction::(&local.selectors)) + .assert_word_eq( + *memory_columns.memory_access.value(), + *memory_columns.memory_access.prev_value(), + ); } /// Evaluates constraints related to loading from memory. @@ -121,12 +168,11 @@ impl CpuChip { // Assert that if `is_lb` and `is_lh` are both true, then the most significant byte // matches the value of `local.mem_value_is_neg`. - builder - .when(local.selectors.is_lb + local.selectors.is_lh) - .assert_eq( - local.mem_value_is_neg, - memory_columns.most_sig_byte_decomp[7], - ); + builder.assert_eq( + local.mem_value_is_neg, + (local.selectors.is_lb + local.selectors.is_lh) + * memory_columns.most_sig_byte_decomp[7], + ); // When the memory value is negative, use the SUB opcode to compute the signed value of // the memory value and verify that the op_a value is correct. @@ -143,6 +189,7 @@ impl CpuChip { signed_value, local.shard, local.channel, + local.unsigned_mem_val_nonce, local.mem_value_is_neg, ); @@ -195,6 +242,11 @@ impl CpuChip { .when(local.selectors.is_sh) .assert_zero(memory_columns.offset_is_one + memory_columns.offset_is_three); + // When the instruction is SW, ensure that the offset is 0. + builder + .when(local.selectors.is_sw) + .assert_one(offset_is_zero.clone()); + // Compute the expected stored value for a SH instruction. let a_is_lower_half = offset_is_zero; let a_is_upper_half = memory_columns.offset_is_two; @@ -247,6 +299,12 @@ impl CpuChip { builder .when(local.selectors.is_lh + local.selectors.is_lhu) .assert_zero(memory_columns.offset_is_one + memory_columns.offset_is_three); + + // When the instruction is LW, ensure that the offset is zero. 
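The alignment constraint just added works entirely on the least significant byte: `addr_word[0] - offset` must be recomposable from `aa_least_sig_byte_decomp`, which only stores bits 2 through 7, so the aligned address is forced to be a multiple of 4. The same relation checked directly on `u32` values:

```rust
/// Check the relation the AIR enforces on the low byte of a memory address.
fn aligned_lsb_relation_holds(addr: u32) -> bool {
    let offset = addr % 4;           // addr_offset
    let aligned = addr - offset;     // addr_aligned
    let aligned_lsb = aligned & 0xFF;

    // aa_least_sig_byte_decomp holds only bits 2..=7 of the aligned low byte.
    let recomposed: u32 = (2..8).map(|i| ((aligned_lsb >> i) & 1) << i).sum();

    // The constraint: addr_word[0] - offset == recomposed byte.
    (addr & 0xFF) - offset == recomposed
}

fn main() {
    for addr in [0u32, 1, 2, 3, 4, 255, 256, 0xDEAD_BEEF, u32::MAX] {
        assert!(aligned_lsb_relation_holds(addr));
    }
}
```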
+ builder + .when(local.selectors.is_lw) + .assert_one(offset_is_zero.clone()); + let use_lower_half = offset_is_zero; let use_upper_half = memory_columns.offset_is_two; let half_value = Word([ @@ -273,9 +331,12 @@ impl CpuChip { local: &CpuCols, unsigned_mem_val: &Word, ) { + let is_mem = self.is_memory_instruction::(&local.selectors); let mut recomposed_byte = AB::Expr::zero(); for i in 0..8 { - builder.assert_bool(memory_columns.most_sig_byte_decomp[i]); + builder + .when(is_mem.clone()) + .assert_bool(memory_columns.most_sig_byte_decomp[i]); recomposed_byte += memory_columns.most_sig_byte_decomp[i] * AB::Expr::from_canonical_u8(1 << i); } diff --git a/core/src/cpu/air/mod.rs b/core/src/cpu/air/mod.rs index 11a985bb5e..4caebbbf73 100644 --- a/core/src/cpu/air/mod.rs +++ b/core/src/cpu/air/mod.rs @@ -22,13 +22,16 @@ use crate::bytes::ByteOpcode; use crate::cpu::columns::OpcodeSelectorCols; use crate::cpu::columns::{CpuCols, NUM_CPU_COLS}; use crate::cpu::CpuChip; +use crate::operations::BabyBearWordRangeChecker; use crate::runtime::Opcode; use super::columns::eval_channel_selectors; +use super::columns::OPCODE_SELECTORS_COL_MAP; impl Air for CpuChip where AB: SP1AirBuilder + AirBuilderWithPublicValues, + AB::Var: Sized, { #[inline(never)] fn eval(&self, builder: &mut AB) { @@ -84,6 +87,7 @@ where local.op_c_val(), local.shard, local.channel, + local.nonce, is_alu_instruction, ); @@ -121,6 +125,27 @@ where // Check that the is_real flag is correct. self.eval_is_real(builder, local, next); + + // Check that when `is_real=0` that all flags that send interactions are zero. + local + .selectors + .into_iter() + .enumerate() + .for_each(|(i, selector)| { + if i == OPCODE_SELECTORS_COL_MAP.imm_b { + builder + .when(AB::Expr::one() - local.is_real) + .assert_one(local.selectors.imm_b); + } else if i == OPCODE_SELECTORS_COL_MAP.imm_c { + builder + .when(AB::Expr::one() - local.is_real) + .assert_one(local.selectors.imm_c); + } else { + builder + .when(AB::Expr::one() - local.is_real) + .assert_zero(selector); + } + }); } } @@ -175,6 +200,26 @@ impl CpuChip { .when(is_jump_instruction.clone()) .assert_eq(jump_columns.next_pc.reduce::(), local.next_pc); + // Range check op_a, pc, and next_pc. + BabyBearWordRangeChecker::::range_check( + builder, + local.op_a_val(), + jump_columns.op_a_range_checker, + is_jump_instruction.clone(), + ); + BabyBearWordRangeChecker::::range_check( + builder, + jump_columns.pc, + jump_columns.pc_range_checker, + local.selectors.is_jal.into(), + ); + BabyBearWordRangeChecker::::range_check( + builder, + jump_columns.next_pc, + jump_columns.next_pc_range_checker, + is_jump_instruction.clone(), + ); + // Verify that the new pc is calculated correctly for JAL instructions. builder.send_alu( AB::Expr::from_canonical_u32(Opcode::ADD as u32), @@ -183,6 +228,7 @@ impl CpuChip { local.op_b_val(), local.shard, local.channel, + jump_columns.jal_nonce, local.selectors.is_jal, ); @@ -194,6 +240,7 @@ impl CpuChip { local.op_c_val(), local.shard, local.channel, + jump_columns.jalr_nonce, local.selectors.is_jalr, ); } @@ -208,6 +255,14 @@ impl CpuChip { .when(local.selectors.is_auipc) .assert_eq(auipc_columns.pc.reduce::(), local.pc); + // Range check the pc. + BabyBearWordRangeChecker::::range_check( + builder, + auipc_columns.pc, + auipc_columns.pc_range_checker, + local.selectors.is_auipc.into(), + ); + // Verify that op_a == pc + op_b. 
builder.send_alu( AB::Expr::from_canonical_u32(Opcode::ADD as u32), @@ -216,6 +271,7 @@ impl CpuChip { local.op_b_val(), local.shard, local.channel, + auipc_columns.auipc_nonce, local.selectors.is_auipc, ); } @@ -288,17 +344,16 @@ impl CpuChip { next: &CpuCols, is_branch_instruction: AB::Expr, ) { - // Verify that if is_sequential_instr is true, assert that local.is_real is true. - // This is needed for the following constraint, which is already degree 3. - builder - .when(local.is_sequential_instr) - .assert_one(local.is_real); - // When is_sequential_instr is true, assert that instruction is not branch, jump, or halt. // Note that the condition `when(local_is_real)` is implied from the previous constraint. let is_halt = self.get_is_halt_syscall::(builder, local); - builder.when(local.is_sequential_instr).assert_zero( - is_branch_instruction + local.selectors.is_jal + local.selectors.is_jalr + is_halt, + builder.when(local.is_real).assert_eq( + local.is_sequential_instr, + AB::Expr::one() + - (is_branch_instruction + + local.selectors.is_jal + + local.selectors.is_jalr + + is_halt), ); // Verify that the pc increments by 4 for all instructions except branch, jump and halt instructions. diff --git a/core/src/cpu/air/register.rs b/core/src/cpu/air/register.rs index e0b989c2bc..23b6551d16 100644 --- a/core/src/cpu/air/register.rs +++ b/core/src/cpu/air/register.rs @@ -57,6 +57,15 @@ impl CpuChip { local.is_real, ); + // Always range check the word value in `op_a`, as JUMP instructions may witness + // an invalid word and write it to memory. + builder.slice_range_check_u8( + &local.op_a_access.access.value.0, + local.shard, + local.channel, + local.is_real, + ); + // If we are performing a branch or a store, then the value of `a` is the previous value. builder .when(is_branch_instruction.clone() + self.is_store_instruction::(&local.selectors)) diff --git a/core/src/cpu/columns/auipc.rs b/core/src/cpu/columns/auipc.rs index a6eb410e7c..fa6871c211 100644 --- a/core/src/cpu/columns/auipc.rs +++ b/core/src/cpu/columns/auipc.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::air::Word; +use crate::{air::Word, operations::BabyBearWordRangeChecker}; pub const NUM_AUIPC_COLS: usize = size_of::>(); @@ -10,4 +10,6 @@ pub const NUM_AUIPC_COLS: usize = size_of::>(); pub struct AuipcCols { /// The current program counter. pub pc: Word, + pub pc_range_checker: BabyBearWordRangeChecker, + pub auipc_nonce: T, } diff --git a/core/src/cpu/columns/branch.rs b/core/src/cpu/columns/branch.rs index 06a77ad306..c6298ef0f7 100644 --- a/core/src/cpu/columns/branch.rs +++ b/core/src/cpu/columns/branch.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::air::Word; +use crate::{air::Word, operations::BabyBearWordRangeChecker}; pub const NUM_BRANCH_COLS: usize = size_of::>(); @@ -11,9 +11,11 @@ pub const NUM_BRANCH_COLS: usize = size_of::>(); pub struct BranchCols { /// The current program counter. pub pc: Word, + pub pc_range_checker: BabyBearWordRangeChecker, /// The next program counter. pub next_pc: Word, + pub next_pc_range_checker: BabyBearWordRangeChecker, /// Whether a equals b. pub a_eq_b: T, @@ -23,4 +25,13 @@ pub struct BranchCols { /// Whether a is less than b. pub a_lt_b: T, + + /// The nonce of the operation to compute `a_lt_b`. + pub a_lt_b_nonce: T, + + /// The nonce of the operation to compute `a_gt_b`. + pub a_gt_b_nonce: T, + + /// The nonce of the operation to compute `next_pc`. 
+ pub next_pc_nonce: T, } diff --git a/core/src/cpu/columns/ecall.rs b/core/src/cpu/columns/ecall.rs index 927b70614c..5d91622c36 100644 --- a/core/src/cpu/columns/ecall.rs +++ b/core/src/cpu/columns/ecall.rs @@ -26,4 +26,7 @@ pub struct EcallCols { /// Field to store the word index passed into the COMMIT ecall. index_bitmap[word index] should /// be set to 1 and everything else set to 0. pub index_bitmap: [T; PV_DIGEST_NUM_WORDS], + + /// The nonce of the syscall operation. + pub syscall_nonce: T, } diff --git a/core/src/cpu/columns/jump.rs b/core/src/cpu/columns/jump.rs index ca94f3ecac..0e1b5701f5 100644 --- a/core/src/cpu/columns/jump.rs +++ b/core/src/cpu/columns/jump.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::air::Word; +use crate::{air::Word, operations::BabyBearWordRangeChecker}; pub const NUM_JUMP_COLS: usize = size_of::>(); @@ -10,7 +10,15 @@ pub const NUM_JUMP_COLS: usize = size_of::>(); pub struct JumpCols { /// The current program counter. pub pc: Word, + pub pc_range_checker: BabyBearWordRangeChecker, - /// THe next program counter. + /// The next program counter. pub next_pc: Word, + pub next_pc_range_checker: BabyBearWordRangeChecker, + + // A range checker for `op_a` which may contain `pc + 4`. + pub op_a_range_checker: BabyBearWordRangeChecker, + + pub jal_nonce: T, + pub jalr_nonce: T, } diff --git a/core/src/cpu/columns/memory.rs b/core/src/cpu/columns/memory.rs index fc54de34c4..baab9e1fc0 100644 --- a/core/src/cpu/columns/memory.rs +++ b/core/src/cpu/columns/memory.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::{air::Word, memory::MemoryReadWriteCols}; +use crate::{air::Word, memory::MemoryReadWriteCols, operations::BabyBearWordRangeChecker}; pub const NUM_MEMORY_COLUMNS: usize = size_of::>(); @@ -17,7 +17,11 @@ pub struct MemoryColumns { // addr_offset = addr_word % 4 // Note that this all needs to be verified in the AIR pub addr_word: Word, + pub addr_word_range_checker: BabyBearWordRangeChecker, + pub addr_aligned: T, + /// The LE bit decomp of the least significant byte of address aligned. + pub aa_least_sig_byte_decomp: [T; 6], pub addr_offset: T, pub memory_access: MemoryReadWriteCols, @@ -28,4 +32,7 @@ pub struct MemoryColumns { // LE bit decomposition for the most significant byte of memory value. This is used to determine // the sign for that value (used for LB and LH). pub most_sig_byte_decomp: [T; 8], + + pub addr_word_nonce: T, + pub unsigned_mem_val_nonce: T, } diff --git a/core/src/cpu/columns/mod.rs b/core/src/cpu/columns/mod.rs index d81bd806fc..968c58362f 100644 --- a/core/src/cpu/columns/mod.rs +++ b/core/src/cpu/columns/mod.rs @@ -40,6 +40,8 @@ pub struct CpuCols { /// The channel value, used for byte lookup multiplicity. pub channel: T, + pub nonce: T, + /// The clock cycle value. This should be within 24 bits. pub clk: T, /// The least significant 16 bit limb of clk. @@ -97,6 +99,8 @@ pub struct CpuCols { /// memory opcodes (i.e. LB, LH, LW, LBU, and LHU). pub unsigned_mem_val: Word, + pub unsigned_mem_val_nonce: T, + /// The result of selectors.is_ecall * the send_to_table column for the ECALL opcode. 
pub ecall_mul_send_to_table: T, diff --git a/core/src/cpu/columns/opcode.rs b/core/src/cpu/columns/opcode.rs index ac67c6934e..80fd63ad3d 100644 --- a/core/src/cpu/columns/opcode.rs +++ b/core/src/cpu/columns/opcode.rs @@ -1,11 +1,23 @@ use p3_field::PrimeField; use sp1_derive::AlignedBorrow; -use std::mem::size_of; +use std::mem::{size_of, transmute}; use std::vec::IntoIter; -use crate::runtime::{Instruction, Opcode}; +use crate::{ + runtime::{Instruction, Opcode}, + utils::indices_arr, +}; pub const NUM_OPCODE_SELECTOR_COLS: usize = size_of::>(); +pub const OPCODE_SELECTORS_COL_MAP: OpcodeSelectorCols = make_selectors_col_map(); + +/// Creates the column map for the CPU. +const fn make_selectors_col_map() -> OpcodeSelectorCols { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_OPCODE_SELECTOR_COLS], OpcodeSelectorCols>(indices_arr) + } +} /// The column layout for opcode selectors. #[derive(AlignedBorrow, Clone, Copy, Default, Debug)] @@ -98,7 +110,7 @@ impl IntoIterator for OpcodeSelectorCols { type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { - vec![ + let columns = vec![ self.imm_b, self.imm_c, self.is_alu, @@ -121,7 +133,8 @@ impl IntoIterator for OpcodeSelectorCols { self.is_jal, self.is_auipc, self.is_unimpl, - ] - .into_iter() + ]; + assert_eq!(columns.len(), NUM_OPCODE_SELECTOR_COLS); + columns.into_iter() } } diff --git a/core/src/cpu/event.rs b/core/src/cpu/event.rs index 2170d91d5d..cdd38f4765 100644 --- a/core/src/cpu/event.rs +++ b/core/src/cpu/event.rs @@ -51,4 +51,15 @@ pub struct CpuEvent { /// Exit code called with halt. pub exit_code: u32, + + pub alu_lookup_id: usize, + pub syscall_lookup_id: usize, + pub memory_add_lookup_id: usize, + pub memory_sub_lookup_id: usize, + pub branch_gt_lookup_id: usize, + pub branch_lt_lookup_id: usize, + pub branch_add_lookup_id: usize, + pub jump_jal_lookup_id: usize, + pub jump_jalr_lookup_id: usize, + pub auipc_lookup_id: usize, } diff --git a/core/src/cpu/trace.rs b/core/src/cpu/trace.rs index b65c4e43ca..893faa385e 100644 --- a/core/src/cpu/trace.rs +++ b/core/src/cpu/trace.rs @@ -1,3 +1,4 @@ +use std::array; use std::borrow::BorrowMut; use std::collections::HashMap; @@ -11,6 +12,8 @@ use tracing::instrument; use super::columns::{CPU_COL_MAP, NUM_CPU_COLS}; use super::{CpuChip, CpuEvent}; use crate::air::MachineAir; +use crate::air::Word; +use crate::alu::create_alu_lookups; use crate::alu::{self, AluEvent}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; @@ -42,7 +45,7 @@ impl MachineAir for CpuChip { let mut rows_with_events = input .cpu_events .par_iter() - .map(|op: &CpuEvent| self.event_to_row::(*op)) + .map(|op: &CpuEvent| self.event_to_row::(*op, &input.nonce_lookup)) .collect::>(); // No need to sort by the shard, since the cpu events are already partitioned by that. @@ -91,7 +94,7 @@ impl MachineAir for CpuChip { let mut alu = HashMap::new(); let mut blu: Vec<_> = Vec::default(); ops.iter().for_each(|op| { - let (_, alu_events, blu_events) = self.event_to_row::(*op); + let (_, alu_events, blu_events) = self.event_to_row::(*op, &HashMap::new()); alu_events.into_iter().for_each(|(key, value)| { alu.entry(key).or_insert(Vec::default()).extend(value); }); @@ -124,6 +127,7 @@ impl CpuChip { fn event_to_row( &self, event: CpuEvent, + nonce_lookup: &HashMap, ) -> ( [F; NUM_CPU_COLS], HashMap>, @@ -138,6 +142,14 @@ impl CpuChip { // Populate shard and clk columns. self.populate_shard_clk(cols, event, &mut new_blu_events); + // Populate the nonce. 
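`OPCODE_SELECTORS_COL_MAP` in the opcode.rs hunk above uses the usual column-map trick: build the array `[0, 1, ..., N-1]` and reinterpret it as the `#[repr(C)]` column struct instantiated with `T = usize`, so every field ends up holding its own column index. A stripped-down, self-contained version of the pattern with a hypothetical three-column struct:

```rust
use core::mem::{size_of, transmute};

#[derive(Clone, Copy, Debug)]
#[repr(C)]
struct DemoCols<T> {
    is_add: T,
    is_sub: T,
    is_real: T,
}

const NUM_DEMO_COLS: usize = size_of::<DemoCols<u8>>();

const fn indices_arr<const N: usize>() -> [usize; N] {
    let mut arr = [0usize; N];
    let mut i = 0;
    while i < N {
        arr[i] = i;
        i += 1;
    }
    arr
}

// Field order in a repr(C) struct matches declaration order, so after the
// transmute `is_sub` holds column index 1, and so on.
const DEMO_COL_MAP: DemoCols<usize> =
    unsafe { transmute::<[usize; NUM_DEMO_COLS], DemoCols<usize>>(indices_arr::<NUM_DEMO_COLS>()) };

fn main() {
    assert_eq!(
        (DEMO_COL_MAP.is_add, DEMO_COL_MAP.is_sub, DEMO_COL_MAP.is_real),
        (0, 1, 2)
    );
}
```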
+ cols.nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.alu_lookup_id) + .copied() + .unwrap_or_default(), + ); + // Populate basic fields. cols.pc = F::from_canonical_u32(event.pc); cols.next_pc = F::from_canonical_u32(event.next_pc); @@ -150,17 +162,45 @@ impl CpuChip { // Populate memory accesses for a, b, and c. if let Some(record) = event.a_record { cols.op_a_access - .populate(event.channel, record, &mut new_blu_events) + .populate(event.channel, record, &mut new_blu_events); } if let Some(MemoryRecordEnum::Read(record)) = event.b_record { cols.op_b_access - .populate(event.channel, record, &mut new_blu_events) + .populate(event.channel, record, &mut new_blu_events); } if let Some(MemoryRecordEnum::Read(record)) = event.c_record { cols.op_c_access - .populate(event.channel, record, &mut new_blu_events) + .populate(event.channel, record, &mut new_blu_events); } + // Populate range checks for a. + let a_bytes = cols + .op_a_access + .access + .value + .0 + .iter() + .map(|x| x.as_canonical_u32()) + .collect::>(); + new_blu_events.push(ByteLookupEvent { + shard: event.shard, + channel: event.channel, + opcode: ByteOpcode::U8Range, + a1: 0, + a2: 0, + b: a_bytes[0], + c: a_bytes[1], + }); + new_blu_events.push(ByteLookupEvent { + shard: event.shard, + channel: event.channel, + opcode: ByteOpcode::U8Range, + a1: 0, + a2: 0, + b: a_bytes[2], + c: a_bytes[3], + }); + // Populate memory accesses for reading from memory. assert_eq!(event.memory_record.is_some(), event.memory.is_some()); let memory_columns = cols.opcode_specific_columns.memory_mut(); @@ -171,19 +211,23 @@ impl CpuChip { } // Populate memory, branch, jump, and auipc specific fields. - self.populate_memory(cols, event, &mut new_alu_events, &mut new_blu_events); - self.populate_branch(cols, event, &mut new_alu_events); - self.populate_jump(cols, event, &mut new_alu_events); - self.populate_auipc(cols, event, &mut new_alu_events); - let is_halt = self.populate_ecall(cols, event); - - if !event.instruction.is_branch_instruction() - && !event.instruction.is_jump_instruction() - && !event.instruction.is_ecall_instruction() - && !is_halt - { - cols.is_sequential_instr = F::one(); - } + self.populate_memory( + cols, + event, + &mut new_alu_events, + &mut new_blu_events, + nonce_lookup, + ); + self.populate_branch(cols, event, &mut new_alu_events, nonce_lookup); + self.populate_jump(cols, event, &mut new_alu_events, nonce_lookup); + self.populate_auipc(cols, event, &mut new_alu_events, nonce_lookup); + let is_halt = self.populate_ecall(cols, event, nonce_lookup); + + cols.is_sequential_instr = F::from_bool( + !event.instruction.is_branch_instruction() + && !event.instruction.is_jump_instruction() + && !is_halt, + ); // Assert that the instruction is not a no-op. cols.is_real = F::one(); @@ -243,6 +287,7 @@ impl CpuChip { event: CpuEvent, new_alu_events: &mut HashMap>, new_blu_events: &mut Vec, + nonce_lookup: &HashMap, ) { if !matches!( event.instruction.opcode, @@ -261,12 +306,20 @@ impl CpuChip { // Populate addr_word and addr_aligned columns. 
let memory_columns = cols.opcode_specific_columns.memory_mut(); let memory_addr = event.b.wrapping_add(event.c); + let aligned_addr = memory_addr - memory_addr % WORD_SIZE as u32; memory_columns.addr_word = memory_addr.into(); - memory_columns.addr_aligned = - F::from_canonical_u32(memory_addr - memory_addr % WORD_SIZE as u32); + memory_columns.addr_word_range_checker.populate(memory_addr); + memory_columns.addr_aligned = F::from_canonical_u32(aligned_addr); + + // Populate the aa_least_sig_byte_decomp columns. + assert!(aligned_addr % 4 == 0); + let aligned_addr_ls_byte = (aligned_addr & 0x000000FF) as u8; + let bits: [bool; 8] = array::from_fn(|i| aligned_addr_ls_byte & (1 << i) != 0); + memory_columns.aa_least_sig_byte_decomp = array::from_fn(|i| F::from_bool(bits[i + 2])); // Add event to ALU check to check that addr == b + c let add_event = AluEvent { + lookup_id: event.memory_add_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -274,11 +327,18 @@ impl CpuChip { a: memory_addr, b: event.b, c: event.c, + sub_lookups: create_alu_lookups(), }; new_alu_events .entry(Opcode::ADD) .and_modify(|op_new_events| op_new_events.push(add_event)) .or_insert(vec![add_event]); + memory_columns.addr_word_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.memory_add_lookup_id) + .copied() + .unwrap_or_default(), + ); // Populate memory offsets. let addr_offset = (memory_addr % WORD_SIZE as u32) as u8; @@ -332,6 +392,7 @@ impl CpuChip { if memory_columns.most_sig_byte_decomp[7] == F::one() { cols.mem_value_is_neg = F::one(); let sub_event = AluEvent { + lookup_id: event.memory_sub_lookup_id, channel: event.channel, shard: event.shard, clk: event.clk, @@ -339,7 +400,14 @@ impl CpuChip { a: event.a, b: cols.unsigned_mem_val.to_u32(), c: sign_value, + sub_lookups: create_alu_lookups(), }; + cols.unsigned_mem_val_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.memory_sub_lookup_id) + .copied() + .unwrap_or_default(), + ); new_alu_events .entry(Opcode::SUB) @@ -370,6 +438,7 @@ impl CpuChip { cols: &mut CpuCols, event: CpuEvent, alu_events: &mut HashMap>, + nonce_lookup: &HashMap, ) { if event.instruction.is_branch_instruction() { let branch_columns = cols.opcode_specific_columns.branch_mut(); @@ -395,8 +464,10 @@ impl CpuChip { } else { Opcode::SLTU }; + // Add the ALU events for the comparisons let lt_comp_event = AluEvent { + lookup_id: event.branch_lt_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -404,7 +475,14 @@ impl CpuChip { a: a_lt_b as u32, b: event.a, c: event.b, + sub_lookups: create_alu_lookups(), }; + branch_columns.a_lt_b_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.branch_lt_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(alu_op_code) @@ -412,6 +490,7 @@ impl CpuChip { .or_insert(vec![lt_comp_event]); let gt_comp_event = AluEvent { + lookup_id: event.branch_gt_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -419,7 +498,14 @@ impl CpuChip { a: a_gt_b as u32, b: event.b, c: event.a, + sub_lookups: create_alu_lookups(), }; + branch_columns.a_gt_b_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.branch_gt_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(alu_op_code) @@ -438,14 +524,17 @@ impl CpuChip { _ => unreachable!(), }; - if branching { - let next_pc = event.pc.wrapping_add(event.c); + let next_pc = event.pc.wrapping_add(event.c); + branch_columns.pc = Word::from(event.pc); + branch_columns.next_pc = Word::from(next_pc); + 
branch_columns.pc_range_checker.populate(event.pc); + branch_columns.next_pc_range_checker.populate(next_pc); + if branching { cols.branching = F::one(); - branch_columns.pc = event.pc.into(); - branch_columns.next_pc = next_pc.into(); let add_event = AluEvent { + lookup_id: event.branch_add_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -453,7 +542,14 @@ impl CpuChip { a: next_pc, b: event.pc, c: event.c, + sub_lookups: create_alu_lookups(), }; + branch_columns.next_pc_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.branch_add_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(Opcode::ADD) @@ -471,6 +567,7 @@ impl CpuChip { cols: &mut CpuCols, event: CpuEvent, alu_events: &mut HashMap>, + nonce_lookup: &HashMap, ) { if event.instruction.is_jump_instruction() { let jump_columns = cols.opcode_specific_columns.jump_mut(); @@ -478,10 +575,14 @@ impl CpuChip { match event.instruction.opcode { Opcode::JAL => { let next_pc = event.pc.wrapping_add(event.b); - jump_columns.pc = event.pc.into(); - jump_columns.next_pc = next_pc.into(); + jump_columns.op_a_range_checker.populate(event.a); + jump_columns.pc = Word::from(event.pc); + jump_columns.pc_range_checker.populate(event.pc); + jump_columns.next_pc = Word::from(next_pc); + jump_columns.next_pc_range_checker.populate(next_pc); let add_event = AluEvent { + lookup_id: event.jump_jal_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -489,7 +590,14 @@ impl CpuChip { a: next_pc, b: event.pc, c: event.b, + sub_lookups: create_alu_lookups(), }; + jump_columns.jal_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.jump_jal_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(Opcode::ADD) @@ -498,9 +606,12 @@ impl CpuChip { } Opcode::JALR => { let next_pc = event.b.wrapping_add(event.c); - jump_columns.next_pc = next_pc.into(); + jump_columns.op_a_range_checker.populate(event.a); + jump_columns.next_pc = Word::from(next_pc); + jump_columns.next_pc_range_checker.populate(next_pc); let add_event = AluEvent { + lookup_id: event.jump_jalr_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -508,7 +619,14 @@ impl CpuChip { a: next_pc, b: event.b, c: event.c, + sub_lookups: create_alu_lookups(), }; + jump_columns.jalr_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.jump_jalr_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(Opcode::ADD) @@ -526,13 +644,16 @@ impl CpuChip { cols: &mut CpuCols, event: CpuEvent, alu_events: &mut HashMap>, + nonce_lookup: &HashMap, ) { if matches!(event.instruction.opcode, Opcode::AUIPC) { let auipc_columns = cols.opcode_specific_columns.auipc_mut(); - auipc_columns.pc = event.pc.into(); + auipc_columns.pc = Word::from(event.pc); + auipc_columns.pc_range_checker.populate(event.pc); let add_event = AluEvent { + lookup_id: event.auipc_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -540,7 +661,14 @@ impl CpuChip { a: event.a, b: event.pc, c: event.b, + sub_lookups: create_alu_lookups(), }; + auipc_columns.auipc_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.auipc_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(Opcode::ADD) @@ -550,7 +678,12 @@ impl CpuChip { } /// Populate columns related to ECALL. 
- fn populate_ecall(&self, cols: &mut CpuCols, _: CpuEvent) -> bool { + fn populate_ecall( + &self, + cols: &mut CpuCols, + event: CpuEvent, + nonce_lookup: &HashMap, + ) -> bool { let mut is_halt = false; if cols.selectors.is_ecall == F::one() { @@ -604,6 +737,14 @@ impl CpuChip { ecall_cols.index_bitmap[digest_idx] = F::one(); } + // Write the syscall nonce. + ecall_cols.syscall_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.syscall_lookup_id) + .copied() + .unwrap_or_default(), + ); + is_halt = syscall_id == F::from_canonical_u32(SyscallCode::HALT.syscall_id()); } @@ -640,41 +781,41 @@ mod tests { use super::*; - use crate::runtime::{tests::simple_program, Instruction, Runtime}; + use crate::runtime::{tests::simple_program, Runtime}; use crate::utils::{run_test, setup_logger, SP1CoreOpts}; - #[test] - fn generate_trace() { - let mut shard = ExecutionRecord::default(); - shard.cpu_events = vec![CpuEvent { - shard: 1, - channel: 0, - clk: 6, - pc: 1, - next_pc: 5, - instruction: Instruction { - opcode: Opcode::ADD, - op_a: 0, - op_b: 1, - op_c: 2, - imm_b: false, - imm_c: false, - }, - a: 1, - a_record: None, - b: 2, - b_record: None, - c: 3, - c_record: None, - memory: None, - memory_record: None, - exit_code: 0, - }]; - let chip = CpuChip::default(); - let trace: RowMajorMatrix = - chip.generate_trace(&shard, &mut ExecutionRecord::default()); - println!("{:?}", trace.values); - } + // #[test] + // fn generate_trace() { + // let mut shard = ExecutionRecord::default(); + // shard.cpu_events = vec![CpuEvent { + // shard: 1, + // channel: 0, + // clk: 6, + // pc: 1, + // next_pc: 5, + // instruction: Instruction { + // opcode: Opcode::ADD, + // op_a: 0, + // op_b: 1, + // op_c: 2, + // imm_b: false, + // imm_c: false, + // }, + // a: 1, + // a_record: None, + // b: 2, + // b_record: None, + // c: 3, + // c_record: None, + // memory: None, + // memory_record: None, + // exit_code: 0, + // }]; + // let chip = CpuChip::default(); + // let trace: RowMajorMatrix = + // chip.generate_trace(&shard, &mut ExecutionRecord::default()); + // println!("{:?}", trace.values); + // } #[test] fn generate_trace_simple_program() { diff --git a/core/src/io.rs b/core/src/io.rs index b07643bbd3..2c4a87953c 100644 --- a/core/src/io.rs +++ b/core/src/io.rs @@ -7,7 +7,7 @@ use num_bigint::BigUint; use serde::{de::DeserializeOwned, Deserialize, Serialize}; /// Standard input for the prover. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct SP1Stdin { /// Input stored as a vec of vec of bytes. It's stored this way because the read syscall reads /// a vec of bytes at a time. @@ -20,7 +20,7 @@ pub struct SP1Stdin { } /// Public values for the prover. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct SP1PublicValues { buffer: Buffer, } @@ -45,7 +45,7 @@ impl SP1Stdin { } /// Read a value from the buffer. - pub fn read(&mut self) -> T { + pub fn read(&mut self) -> T { let result: T = bincode::deserialize(&self.buffer[self.ptr]).expect("failed to deserialize"); self.ptr += 1; @@ -121,7 +121,7 @@ impl SP1PublicValues { } /// Write a value to the buffer. 
- pub fn write(&mut self, data: &T) { + pub fn write(&mut self, data: &T) { self.buffer.write(data); } diff --git a/core/src/lib.rs b/core/src/lib.rs index da90db9cf4..8a862fcc2b 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -33,3 +33,10 @@ pub mod utils; #[allow(unused_imports)] use runtime::{Program, Runtime}; use stark::StarkGenericConfig; + +/// The global version for all components of SP1. +/// +/// This string should be updated whenever any step in verifying an SP1 proof changes, including +/// core, recursion, and plonk-bn254. This string is used to download SP1 artifacts and the gnark +/// docker image. +pub const SP1_CIRCUIT_VERSION: &str = "v1.0.7-testnet"; diff --git a/core/src/lookup/interaction.rs b/core/src/lookup/interaction.rs index 74b7a9fc06..1c20938cc4 100644 --- a/core/src/lookup/interaction.rs +++ b/core/src/lookup/interaction.rs @@ -74,7 +74,6 @@ impl Interaction { } } -// TODO: add debug for VirtualPairCol so that we can derive Debug for Interaction. impl Debug for Interaction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Interaction") diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index 3786dd4ca1..a6cf49d02a 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -1,5 +1,6 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; +use std::array; use p3_air::BaseAir; use p3_air::{Air, AirBuilder}; @@ -10,8 +11,9 @@ use p3_matrix::Matrix; use sp1_derive::AlignedBorrow; use super::MemoryInitializeFinalizeEvent; -use crate::air::{AirInteraction, SP1AirBuilder, Word}; -use crate::air::{MachineAir, WordAirBuilder}; +use crate::air::MachineAir; +use crate::air::{AirInteraction, BaseAirBuilder, SP1AirBuilder}; +use crate::operations::BabyBearBitDecomposition; use crate::runtime::{ExecutionRecord, Program}; use crate::utils::pad_to_power_of_two; @@ -62,7 +64,7 @@ impl MachineAir for MemoryChip { MemoryChipType::Finalize => input.memory_finalize_events.clone(), }; memory_events.sort_by_key(|event| event.addr); - let rows: Vec<[F; 8]> = (0..memory_events.len()) // TODO: change this back to par_iter + let rows: Vec<[F; NUM_MEMORY_INIT_COLS]> = (0..memory_events.len()) // OPT: change this to par_iter .map(|i| { let MemoryInitializeFinalizeEvent { addr, @@ -71,14 +73,37 @@ impl MachineAir for MemoryChip { timestamp, used, } = memory_events[i]; + let mut row = [F::zero(); NUM_MEMORY_INIT_COLS]; let cols: &mut MemoryInitCols = row.as_mut_slice().borrow_mut(); cols.addr = F::from_canonical_u32(addr); + cols.addr_bits.populate(addr); cols.shard = F::from_canonical_u32(shard); cols.timestamp = F::from_canonical_u32(timestamp); - cols.value = value.into(); + cols.value = array::from_fn(|i| F::from_canonical_u32((value >> i) & 1)); cols.is_real = F::from_canonical_u32(used); + if i != memory_events.len() - 1 { + let next_addr = memory_events[i + 1].addr; + assert_ne!(next_addr, addr); + + cols.addr_bits.populate(addr); + + cols.seen_diff_bits[0] = F::zero(); + for j in 0..32 { + let rev_j = 32 - j - 1; + let next_bit = ((next_addr >> rev_j) & 1) == 1; + let local_bit = ((addr >> rev_j) & 1) == 1; + cols.match_bits[j] = + F::from_bool((local_bit && next_bit) || (!local_bit && !next_bit)); + cols.seen_diff_bits[j + 1] = cols.seen_diff_bits[j] + + (F::one() - cols.seen_diff_bits[j]) * (F::one() - cols.match_bits[j]); + cols.not_match_and_not_seen_diff_bits[j] = + (F::one() - cols.match_bits[j]) * (F::one() - cols.seen_diff_bits[j]); + } + assert_eq!(cols.seen_diff_bits[cols.seen_diff_bits.len() 
- 1], F::one()); + } + row }) .collect::>(); @@ -101,7 +126,7 @@ impl MachineAir for MemoryChip { } } -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[derive(AlignedBorrow, Debug, Clone, Copy)] #[repr(C)] pub struct MemoryInitCols { /// The shard number of the memory access. @@ -113,8 +138,20 @@ pub struct MemoryInitCols { /// The address of the memory access. pub addr: T, + /// A bit decomposition of `addr`. + pub addr_bits: BabyBearBitDecomposition, + + // Whether the i'th bit matches the next addr's bit. + pub match_bits: [T; 32], + + // Whether we've seen a different bit in the comparison. + pub seen_diff_bits: [T; 33], + + // Whether the i'th bit doesn't match the next addr's bit and we haven't seen a diff bitn yet. + pub not_match_and_not_seen_diff_bits: [T; 32], + /// The value of the memory access. - pub value: Word, + pub value: [T; 32], /// Whether the memory access is a real access. pub is_real: T, @@ -130,10 +167,29 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &MemoryInitCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &MemoryInitCols = (*next).borrow(); + + builder.assert_bool(local.is_real); + for i in 0..32 { + builder.assert_bool(local.value[i]); + } + + let mut byte1 = AB::Expr::zero(); + let mut byte2 = AB::Expr::zero(); + let mut byte3 = AB::Expr::zero(); + let mut byte4 = AB::Expr::zero(); + for i in 0..8 { + byte1 += local.value[i].into() * AB::F::from_canonical_u8(1 << i); + byte2 += local.value[i + 8].into() * AB::F::from_canonical_u8(1 << i); + byte3 += local.value[i + 16].into() * AB::F::from_canonical_u8(1 << i); + byte4 += local.value[i + 24].into() * AB::F::from_canonical_u8(1 << i); + } + let value = [byte1, byte2, byte3, byte4]; if self.kind == MemoryChipType::Initialize { let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), local.addr.into()]; - values.extend(local.value.map(Into::into)); + values.extend(value.map(Into::into)); builder.receive(AirInteraction::new( values, local.is_real.into(), @@ -145,7 +201,7 @@ where local.timestamp.into(), local.addr.into(), ]; - values.extend(local.value.map(Into::into)); + values.extend(value); builder.send(AirInteraction::new( values, local.is_real.into(), @@ -153,16 +209,106 @@ where )); } + // We want to assert addr < addr'. Assume seen_diff_0 = 0. + // + // match_i = (addr_i & addr'_i) || (!addr_i & !addr'_i) + // => + // match_i == addr_i * addr_i + (1 - addr_i) * (1 - addr'_i) + // + // when !match_i and !seen_diff_i, then enforce (addr_i == 0) and (addr'_i == 1). + // if seen_diff_i: + // seen_diff_{i+1} = 1 + // else: + // seen_diff_{i+1} = !match_i + // => + // builder.when(!match_i * !seen_diff_i).assert_zero(addr_i) + // builder.when(!match_i * !seen_diff_i).assert_one(addr'_i) + // seen_diff_bit_{i+1} == seen_diff_i + (1-seen_diff_i) * (1 - match_i) + // + // at the end of the algorithm, assert that we've seen a diff bit. + // => + // seen_diff_bit_{last} == 1 + + // Assert that we start with assuming that we haven't seen a diff bit. + builder.assert_zero(local.seen_diff_bits[0]); + + for i in 0..local.addr_bits.bits.len() { + // Compute the i'th msb bit's index. + let rev_i = local.addr_bits.bits.len() - i - 1; + + // Compute whether the i'th msb bit matches. 
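Storing the memory value as 32 single-bit columns (instead of a `Word`) means the interaction values have to be rebuilt: the eval code above folds bits 0..=7 into the first byte, bits 8..=15 into the second, and so on. The same recomposition on concrete integers:

```rust
/// Recompose the 32 single-bit value columns into the 4 little-endian bytes
/// that the memory interaction expects.
fn bytes_from_bits(bits: [u32; 32]) -> [u32; 4] {
    let mut bytes = [0u32; 4];
    for i in 0..8 {
        for b in 0..4 {
            bytes[b] += bits[i + 8 * b] << i;
        }
    }
    bytes
}

fn main() {
    let value: u32 = 0xDEAD_BEEF;
    let bits: [u32; 32] = core::array::from_fn(|i| (value >> i) & 1);
    assert_eq!(bytes_from_bits(bits), [0xEF, 0xBE, 0xAD, 0xDE]);
}
```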
+ let match_i = local.addr_bits.bits[rev_i] * next.addr_bits.bits[rev_i] + + (AB::Expr::one() - local.addr_bits.bits[rev_i]) + * (AB::Expr::one() - next.addr_bits.bits[rev_i]); + builder + .when_transition() + .when(next.is_real) + .assert_eq(match_i.clone(), local.match_bits[i]); + + // Compute whether it's not a match and we haven't seen a diff bit. + let not_match_and_not_seen_diff_i = (AB::Expr::one() - local.match_bits[i]) + * (AB::Expr::one() - local.seen_diff_bits[i]); + builder.when_transition().when(next.is_real).assert_eq( + local.not_match_and_not_seen_diff_bits[i], + not_match_and_not_seen_diff_i, + ); + + // If the i'th msb bit doesn't match and it's the first time we've seen a diff bit, + // then enforce that the next bit is one and the current bit is zero. + builder + .when_transition() + .when(local.not_match_and_not_seen_diff_bits[i]) + .when(next.is_real) + .assert_zero(local.addr_bits.bits[rev_i]); + builder + .when_transition() + .when(local.not_match_and_not_seen_diff_bits[i]) + .when(next.is_real) + .assert_one(next.addr_bits.bits[rev_i]); + + // Update the seen diff bits. + builder.when_transition().assert_eq( + local.seen_diff_bits[i + 1], + local.seen_diff_bits[i] + local.not_match_and_not_seen_diff_bits[i], + ); + } + + // Assert that on rows where the next row is real, we've seen a diff bit. + builder + .when_transition() + .when(next.is_real) + .assert_one(local.seen_diff_bits[local.addr_bits.bits.len()]); + + // Canonically decompose the address into bits so we can do comparisons. + BabyBearBitDecomposition::::range_check( + builder, + local.addr, + local.addr_bits, + local.is_real.into(), + ); + + // Assert that the real rows are all padded to the top. + builder + .when_transition() + .when_not(local.is_real) + .assert_zero(next.is_real); + + if self.kind == MemoryChipType::Initialize { + builder + .when(local.is_real) + .assert_eq(local.timestamp, AB::F::one()); + } + // Register %x0 should always be 0. See 2.6 Load and Store Instruction on // P.18 of the RISC-V spec. To ensure that, we expect that the first row of the Initialize // and Finalize global memory chip is for register %x0 (i.e. addr = 0x0), and that those rows // have a value of 0. Additionally, in the CPU air, we ensure that whenever op_a is set to // %x0, its value is 0. - // - // TODO: Add a similar check for MemoryChipType::Initialize. - if self.kind == MemoryChipType::Finalize { + if self.kind == MemoryChipType::Initialize || self.kind == MemoryChipType::Finalize { builder.when_first_row().assert_zero(local.addr); - builder.when_first_row().assert_word_zero(local.value); + for i in 0..32 { + builder.when_first_row().assert_zero(local.value[i]); + } } } } diff --git a/core/src/memory/mod.rs b/core/src/memory/mod.rs index 7acdee1fb8..4246db14c8 100644 --- a/core/src/memory/mod.rs +++ b/core/src/memory/mod.rs @@ -27,7 +27,7 @@ impl MemoryInitializeFinalizeEvent { addr, value, shard: 0, - timestamp: 0, + timestamp: 1, used: if used { 1 } else { 0 }, } } diff --git a/core/src/memory/program.rs b/core/src/memory/program.rs index 3d922c4ae6..ace3a6c219 100644 --- a/core/src/memory/program.rs +++ b/core/src/memory/program.rs @@ -31,9 +31,12 @@ pub struct MemoryProgramPreprocessedCols { #[derive(AlignedBorrow, Clone, Copy, Default)] #[repr(C)] pub struct MemoryProgramMultCols { - /// The multiplicity of the event, must be 1 in the first shard and 0 otherwise. + /// The multiplicity of the event. + /// + /// This column is technically redundant with `is_real`, but it's included for clarity. 
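The `match_bits` / `seen_diff_bits` columns constrained above implement a most-significant-bit-first comparison: at the first position where the two addresses disagree, the current address must carry a 0 and the next address a 1, and a disagreement must occur somewhere. For consecutive rows that is exactly `addr < next_addr`. The same scan in plain Rust, checked against the native comparison:

```rust
/// Re-run the bit scan the memory-init AIR performs and report whether it
/// would accept the pair of consecutive addresses.
fn scan_accepts(addr: u32, next_addr: u32) -> bool {
    let mut seen_diff = false;
    for j in 0..32 {
        let rev = 31 - j; // walk from the most significant bit down
        let a = (addr >> rev) & 1;
        let b = (next_addr >> rev) & 1;
        let bits_match = a == b;
        if !bits_match && !seen_diff {
            // First disagreeing bit: the AIR forces addr's bit to 0 and next's to 1.
            if !(a == 0 && b == 1) {
                return false;
            }
        }
        seen_diff = seen_diff || !bits_match;
    }
    // The AIR also asserts a difference was seen, i.e. the addresses are distinct.
    seen_diff
}

fn main() {
    for &(a, b) in &[(0u32, 1u32), (5, 4), (7, 7), (0x80, 0x100), (u32::MAX - 1, u32::MAX)] {
        assert_eq!(scan_accepts(a, b), a < b);
    }
}
```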
pub multiplicity: T, - /// Columns to see if current shard is 1. + + /// Whether the shard is the first shard. pub is_first_shard: IsZeroOperation<T>, }
@@ -74,7 +77,6 @@ impl MachineAir for MemoryProgramChip { cols.addr = F::from_canonical_u32(addr); cols.value = Word::from(word); cols.is_real = F::one(); - row }) .collect::<Vec<_>>();
@@ -120,8 +122,7 @@ impl MachineAir for MemoryProgramChip { let mut row = [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS]; let cols: &mut MemoryProgramMultCols<F> = row.as_mut_slice().borrow_mut(); cols.multiplicity = mult; - IsZeroOperation::populate(&mut cols.is_first_shard, input.index - 1); - + cols.is_first_shard.populate(input.index - 1); row }) .collect::<Vec<_>>();
@@ -171,23 +172,26 @@ where .map(|elm| (*elm).into()) .collect::<Vec<_>>(), ); + + // Constrain `is_first_shard` to be 1 if and only if the shard is the first shard. IsZeroOperation::<AB::F>::eval( builder, - public_values.shard - AB::Expr::one(), + public_values.shard - AB::F::one(), mult_local.is_first_shard, prep_local.is_real.into(), ); - let is_first_shard = mult_local.is_first_shard.result; // Multiplicity must be either 0 or 1. builder.assert_bool(mult_local.multiplicity); + // If first shard and preprocessed is real, multiplicity must be one. builder - .when(is_first_shard * prep_local.is_real) - .assert_one(mult_local.multiplicity); - // If not first shard or preprocessed is not real, multiplicity must be zero. + .when(mult_local.is_first_shard.result) + .assert_eq(mult_local.multiplicity, prep_local.is_real.into()); + + // If it's not the first shard, then the multiplicity must be zero. builder - .when((AB::Expr::one() - is_first_shard) + (AB::Expr::one() - prep_local.is_real)) + .when_not(mult_local.is_first_shard.result) + .assert_zero(mult_local.multiplicity); let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), prep_local.addr.into()];
diff --git a/core/src/operations/baby_bear_range.rs b/core/src/operations/baby_bear_range.rs new file mode 100644 index 0000000000..7e1ad0ef42 --- /dev/null +++ b/core/src/operations/baby_bear_range.rs
@@ -0,0 +1,88 @@ +use std::array; + +use p3_air::AirBuilder; +use p3_field::{AbstractField, Field}; +use sp1_derive::AlignedBorrow; + +use crate::stark::SP1AirBuilder; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct BabyBearBitDecomposition<T> { + /// The bit decomposition of the `value`. + pub bits: [T; 32], + + /// The product of the bits 3 to 5 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_5: T, + + /// The product of the bits 3 to 6 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_6: T, +
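Intuitively, these columns form a running product over bits 27 through 30 of the decomposed value (abbreviating the column names, and following the indexing used by `populate` below): and_3_to_5 = bits[27] * bits[28], and_3_to_6 = and_3_to_5 * bits[29], and_3_to_7 = and_3_to_6 * bits[30]. With bit 31 separately constrained to zero, the final product is 1 exactly when the value is at least 0x7800_0000, which is the only case where the range check still has to force the remaining low bits to zero.

+ /// The product of the bits 3 to 7 in `most_sig_byte_decomp`.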
+ pub and_most_sig_byte_decomp_3_to_7: T, +} + +impl<F: Field> BabyBearBitDecomposition<F> { + pub fn populate(&mut self, value: u32) { + self.bits = array::from_fn(|i| F::from_canonical_u32((value >> i) & 1)); + let most_sig_byte_decomp = &self.bits[24..32]; + self.and_most_sig_byte_decomp_3_to_5 = most_sig_byte_decomp[3] * most_sig_byte_decomp[4]; + self.and_most_sig_byte_decomp_3_to_6 = + self.and_most_sig_byte_decomp_3_to_5 * most_sig_byte_decomp[5]; + self.and_most_sig_byte_decomp_3_to_7 = + self.and_most_sig_byte_decomp_3_to_6 * most_sig_byte_decomp[6]; + } + + pub fn range_check<AB: SP1AirBuilder>( + builder: &mut AB, + value: AB::Var, + cols: BabyBearBitDecomposition<AB::Var>, + is_real: AB::Expr, + ) { + let mut reconstructed_value = AB::Expr::zero(); + for (i, bit) in cols.bits.iter().enumerate() { + builder.when(is_real.clone()).assert_bool(*bit); + reconstructed_value += AB::Expr::from_wrapped_u32(1 << i) * *bit; + } + + // Assert that bits2num(bits) == value. + builder + .when(is_real.clone()) + .assert_eq(reconstructed_value, value); + + // Range check that value is less than baby bear modulus. To do this, it is sufficient + // to just do comparisons for the most significant byte. BabyBear's modulus is (in big endian binary) + // 01111000_00000000_00000000_00000001. So we need to check the following conditions: + // 1) if most_sig_byte > 01111000, then fail. + // 2) if most_sig_byte == 01111000, then value's lower sig bytes must all be 0. + // 3) if most_sig_byte < 01111000, then pass. + let most_sig_byte_decomp = &cols.bits[24..32]; + builder + .when(is_real.clone()) + .assert_zero(most_sig_byte_decomp[7]); + + // Compute the product of the "top bits". + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_5, + most_sig_byte_decomp[3] * most_sig_byte_decomp[4], ); + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_6, + cols.and_most_sig_byte_decomp_3_to_5 * most_sig_byte_decomp[5], ); + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_7, + cols.and_most_sig_byte_decomp_3_to_6 * most_sig_byte_decomp[6], ); + + // If the top bits (27 through 30) are all 1, then the lower bits must all be 0. + let mut lower_bits_sum: AB::Expr = AB::Expr::zero(); + for bit in cols.bits[0..27].iter() { + lower_bits_sum = lower_bits_sum + *bit; + } + builder + .when(is_real) + .when(cols.and_most_sig_byte_decomp_3_to_7) + .assert_zero(lower_bits_sum); + } +}
diff --git a/core/src/operations/baby_bear_word.rs b/core/src/operations/baby_bear_word.rs new file mode 100644 index 0000000000..2e773b3e6d --- /dev/null +++ b/core/src/operations/baby_bear_word.rs
@@ -0,0 +1,94 @@ +use std::array; + +use p3_air::AirBuilder; +use p3_field::{AbstractField, Field}; +use sp1_derive::AlignedBorrow; + +use crate::{air::Word, stark::SP1AirBuilder}; + +/// A set of columns needed to range check a BabyBear word. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct BabyBearWordRangeChecker<T> { + /// Most sig byte LE bit decomposition. + pub most_sig_byte_decomp: [T; 8], + + /// The product of the bits 3 to 5 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_5: T, + + /// The product of the bits 3 to 6 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_6: T, +
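Both of these helpers enforce what is, in ordinary Rust, a single comparison against the BabyBear modulus; a non-circuit sketch (the constant and function name are illustrative only):

    fn is_canonical_baby_bear(value: u32) -> bool {
        // BabyBear modulus: 2^31 - 2^27 + 1 = 0b01111000_00000000_00000000_00000001.
        const BABY_BEAR_MODULUS: u32 = 0x7800_0001;
        value < BABY_BEAR_MODULUS
    }

Because the AIR cannot compare 32-bit integers directly, it instead forces bit 31 to zero and, whenever bits 27 through 30 are all set (i.e. the value is at least 0x7800_0000), forces every lower bit or byte limb to zero as well, leaving 0x7800_0000 as the only admissible value in that range.

+ /// The product of the bits 3 to 7 in `most_sig_byte_decomp`.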
+ pub and_most_sig_byte_decomp_3_to_7: T, +} + +impl BabyBearWordRangeChecker { + pub fn populate(&mut self, value: u32) { + self.most_sig_byte_decomp = array::from_fn(|i| F::from_bool(value & (1 << (i + 24)) != 0)); + self.and_most_sig_byte_decomp_3_to_5 = + self.most_sig_byte_decomp[3] * self.most_sig_byte_decomp[4]; + self.and_most_sig_byte_decomp_3_to_6 = + self.and_most_sig_byte_decomp_3_to_5 * self.most_sig_byte_decomp[5]; + self.and_most_sig_byte_decomp_3_to_7 = + self.and_most_sig_byte_decomp_3_to_6 * self.most_sig_byte_decomp[6]; + } + + pub fn range_check( + builder: &mut AB, + value: Word, + cols: BabyBearWordRangeChecker, + is_real: AB::Expr, + ) { + let mut recomposed_byte = AB::Expr::zero(); + cols.most_sig_byte_decomp + .iter() + .enumerate() + .for_each(|(i, value)| { + builder.when(is_real.clone()).assert_bool(*value); + recomposed_byte = + recomposed_byte.clone() + AB::Expr::from_canonical_usize(1 << i) * *value; + }); + + builder + .when(is_real.clone()) + .assert_eq(recomposed_byte, value[3]); + + // Range check that value is less than baby bear modulus. To do this, it is sufficient + // to just do comparisons for the most significant byte. BabyBear's modulus is (in big endian binary) + // 01111000_00000000_00000000_00000001. So we need to check the following conditions: + // 1) if most_sig_byte > 01111000, then fail. + // 2) if most_sig_byte == 01111000, then value's lower sig bytes must all be 0. + // 3) if most_sig_byte < 01111000, then pass. + builder + .when(is_real.clone()) + .assert_zero(cols.most_sig_byte_decomp[7]); + + // Compute the product of the "top bits". + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_5, + cols.most_sig_byte_decomp[3] * cols.most_sig_byte_decomp[4], + ); + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_6, + cols.and_most_sig_byte_decomp_3_to_5 * cols.most_sig_byte_decomp[5], + ); + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_7, + cols.and_most_sig_byte_decomp_3_to_6 * cols.most_sig_byte_decomp[6], + ); + + let bottom_bits: AB::Expr = cols.most_sig_byte_decomp[0..3] + .iter() + .map(|bit| (*bit).into()) + .sum(); + builder + .when(is_real.clone()) + .when(cols.and_most_sig_byte_decomp_3_to_7) + .assert_zero(bottom_bits); + builder + .when(is_real) + .when(cols.and_most_sig_byte_decomp_3_to_7) + .assert_zero(value[0] + value[1] + value[2]); + } +} diff --git a/core/src/operations/field/field_op.rs b/core/src/operations/field/field_op.rs index ae04e2b9b1..995142c2f2 100644 --- a/core/src/operations/field/field_op.rs +++ b/core/src/operations/field/field_op.rs @@ -445,7 +445,6 @@ mod tests { let mut challenger = config.challenger(); - // TODO: test with other fields let chip: FieldOpChip = FieldOpChip::new(*op); let shard = ExecutionRecord::default(); let trace: RowMajorMatrix = diff --git a/core/src/operations/field/field_sqrt.rs b/core/src/operations/field/field_sqrt.rs index c0401a1d46..e16de147bd 100644 --- a/core/src/operations/field/field_sqrt.rs +++ b/core/src/operations/field/field_sqrt.rs @@ -83,6 +83,20 @@ impl FieldSqrtCols { }; record.add_byte_lookup_event(and_event); + // Add the byte range check for `sqrt`. + record.add_u8_range_checks( + shard, + channel, + self.multiplication + .result + .0 + .as_slice() + .iter() + .map(|x| x.as_canonical_u32() as u8) + .collect::>() + .as_slice(), + ); + sqrt } } @@ -129,6 +143,14 @@ where is_real.clone(), ); + // Range check that `sqrt` limbs are bytes. 
+ builder.slice_range_check_u8( + sqrt.0.as_slice(), + shard.clone(), + channel.clone(), + is_real.clone(), + ); + // Assert that the square root is the positive one, i.e., with least significant bit 0. // This is done by computing LSB = least_significant_byte & 1. builder.assert_bool(self.lsb);
diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 242c9100b1..e3fbcc78b1 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs
@@ -8,6 +8,8 @@ mod add; mod add4; mod add5; mod and; +mod baby_bear_range; +mod baby_bear_word; pub mod field; mod fixed_rotate_right; mod fixed_shift_right;
@@ -22,6 +24,8 @@ pub use add::*; pub use add4::*; pub use add5::*; pub use and::*; +pub use baby_bear_range::*; +pub use baby_bear_word::*; pub use fixed_rotate_right::*; pub use fixed_shift_right::*; pub use is_equal_word::*;
diff --git a/core/src/operations/or.rs b/core/src/operations/or.rs index 8cb3f00191..b30821532a 100644 --- a/core/src/operations/or.rs +++ b/core/src/operations/or.rs
@@ -10,8 +10,6 @@ use crate::disassembler::WORD_SIZE; use crate::runtime::ExecutionRecord; /// A set of columns needed to compute the or of two words. -/// -/// TODO: This is currently not in use, and thus not tested thoroughly yet. #[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct OrOperation<T> {
diff --git a/core/src/runtime/hooks.rs b/core/src/runtime/hooks.rs new file mode 100644 index 0000000000..e9f503076f --- /dev/null +++ b/core/src/runtime/hooks.rs
@@ -0,0 +1,123 @@ +use std::collections::HashMap; + +use k256::ecdsa::{RecoveryId, Signature, VerifyingKey}; +use k256::elliptic_curve::ops::Invert; + +use super::Runtime; + +pub trait Hook: Fn(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send {} + +impl<F: Fn(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send> Hook for F {} + +pub type BoxedHook<'a> = Box<dyn Hook + 'a>; + +/// The file descriptor through which to access `hook_ecrecover`. +pub const FD_ECRECOVER_HOOK: u32 = 5; + +/// A registry of hooks to call, indexed by the file descriptors through which they are accessed. +pub struct HookRegistry<'a> { + /// Table of registered hooks. Prefer using `Runtime::hook` and + /// `HookRegistry::register` over interacting with this field directly. + pub table: HashMap<u32, BoxedHook<'a>>, +} + +impl<'a> HookRegistry<'a> { + /// Create a registry with the default hooks. + pub fn new() -> Self { + Default::default() + } + + /// Create an empty registry. + pub fn empty() -> Self { + Self { + table: Default::default(), + } + } + + /// Register a hook under a given name. + pub fn register(&mut self, name: u32, hook: BoxedHook<'a>) { + self.table.insert(name, hook); + } +} + +impl<'a> Default for HookRegistry<'a> { + fn default() -> Self { + // When `LazyCell` gets stabilized (1.81.0), we can use it to avoid unnecessary allocations. + let table = { + let entries: Vec<(u32, BoxedHook)> = vec![ + // Note: To ensure any `fd` value is synced with `zkvm/precompiles/src/io.rs`, + // add an assertion to the test `hook_fds_match` below. + (FD_ECRECOVER_HOOK, Box::new(hook_ecrecover)), + ]; + HashMap::from_iter(entries) + }; + + Self { table } + } +} +
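A hook is just a `Send` function (or closure) from a `HookEnv` and an input byte slice to a list of output byte vectors. As a hypothetical usage sketch (the file descriptor 1001 and the echo hook are made up; only `FD_ECRECOVER_HOOK` is registered by default):

    fn hook_echo(_env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
        vec![buf.to_vec()]
    }

    let mut registry = HookRegistry::new();
    registry.register(1001, Box::new(hook_echo));

A runtime whose `hook_registry` holds this table can then serve the hook on the host side via `runtime.hook(1001, data)`.

+/// Environment that a hook may read from.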
+pub struct HookEnv<'a, 'b: 'a> { + pub runtime: &'a Runtime<'b>, +} + +pub fn hook_ecrecover(_env: HookEnv, buf: &[u8]) -> Vec> { + assert_eq!( + buf.len(), + 65 + 32, + "ecrecover input should have length 65 + 32" + ); + let (sig, msg_hash) = buf.split_at(65); + let sig: &[u8; 65] = sig.try_into().unwrap(); + let msg_hash: &[u8; 32] = msg_hash.try_into().unwrap(); + + let mut recovery_id = sig[64]; + let mut sig = Signature::from_slice(&sig[..64]).unwrap(); + + if let Some(sig_normalized) = sig.normalize_s() { + sig = sig_normalized; + recovery_id ^= 1 + }; + let recid = RecoveryId::from_byte(recovery_id).expect("Recovery ID is valid"); + + let recovered_key = VerifyingKey::recover_from_prehash(&msg_hash[..], &sig, recid).unwrap(); + let bytes = recovered_key.to_sec1_bytes(); + + let (_, s) = sig.split_scalars(); + let s_inverse = s.invert(); + + vec![bytes.to_vec(), s_inverse.to_bytes().to_vec()] +} + +#[cfg(test)] +pub mod tests { + use crate::{ + runtime::Program, + utils::{self, tests::ECRECOVER_ELF}, + }; + + use super::*; + + #[test] + pub fn hook_fds_match() { + use sp1_zkvm::precompiles::io; + assert_eq!(FD_ECRECOVER_HOOK, io::FD_ECRECOVER_HOOK) + } + + #[test] + pub fn registry_new_is_inhabited() { + assert_ne!(HookRegistry::new().table.len(), 0); + } + + #[test] + pub fn registry_empty_is_empty() { + assert_eq!(HookRegistry::empty().table.len(), 0); + } + + #[test] + fn test_ecrecover_program_prove() { + utils::setup_logger(); + let program = Program::from(ECRECOVER_ELF); + utils::run_test(program).unwrap(); + } +} diff --git a/core/src/runtime/io.rs b/core/src/runtime/io.rs index f404a8a554..a0bb76508f 100644 --- a/core/src/runtime/io.rs +++ b/core/src/runtime/io.rs @@ -8,14 +8,14 @@ use serde::Serialize; use super::Runtime; -impl Read for Runtime { +impl<'a> Read for Runtime<'a> { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { self.read_public_values_slice(buf); Ok(buf.len()) } } -impl Runtime { +impl<'a> Runtime<'a> { pub fn write_stdin(&mut self, input: &T) { let mut buf = Vec::new(); bincode::serialize_into(&mut buf, input).expect("serialization failed"); diff --git a/core/src/runtime/mod.rs b/core/src/runtime/mod.rs index 003d35e8c5..1692181d91 100644 --- a/core/src/runtime/mod.rs +++ b/core/src/runtime/mod.rs @@ -1,3 +1,4 @@ +mod hooks; mod instruction; mod io; mod memory; @@ -5,24 +6,28 @@ mod opcode; mod program; mod record; mod register; +mod report; mod state; mod syscall; #[macro_use] mod utils; +mod subproof; +pub use hooks::*; pub use instruction::*; pub use memory::*; pub use opcode::*; pub use program::*; pub use record::*; pub use register::*; +pub use report::*; pub use state::*; +pub use subproof::*; pub use syscall::*; pub use utils::*; use std::collections::hash_map::Entry; use std::collections::HashMap; -use std::fmt::{Display, Formatter, Result as FmtResult}; use std::fs::File; use std::io::BufWriter; use std::io::Write; @@ -30,6 +35,8 @@ use std::sync::Arc; use thiserror::Error; +use crate::alu::create_alu_lookup_id; +use crate::alu::create_alu_lookups; use crate::bytes::NUM_BYTE_LOOKUP_CHANNELS; use crate::memory::MemoryInitializeFinalizeEvent; use crate::utils::SP1CoreOpts; @@ -42,7 +49,7 @@ use crate::{alu::AluEvent, cpu::CpuEvent}; /// /// For more information on the RV32IM instruction set, see the following: /// https://www.cs.sfu.ca/~ashriram/Courses/CS295/assets/notebooks/RISCV/RISCV_CARD.pdf -pub struct Runtime { +pub struct Runtime<'a> { /// The program. 
pub program: Arc, @@ -87,49 +94,13 @@ pub struct Runtime { pub report: ExecutionReport, /// Whether we should write to the report. - pub should_report: bool, -} - -#[derive(Default, Debug, Clone, PartialEq, Eq)] -pub struct ExecutionReport { - pub instruction_counts: HashMap, - pub syscall_counts: HashMap, -} - -impl ExecutionReport { - pub fn total_instruction_count(&self) -> u64 { - self.instruction_counts.values().sum() - } - - pub fn total_syscall_count(&self) -> u64 { - self.syscall_counts.values().sum() - } -} - -impl Display for ExecutionReport { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - writeln!(f, "Instruction Counts:")?; - let mut sorted_instructions = self.instruction_counts.iter().collect::>(); - - // Sort instructions by opcode name - sorted_instructions.sort_by_key(|&(opcode, _)| opcode.to_string()); - for (opcode, count) in sorted_instructions { - writeln!(f, " {}: {}", opcode, count)?; - } - writeln!(f, "Total Instructions: {}", self.total_instruction_count())?; - - writeln!(f, "Syscall Counts:")?; - let mut sorted_syscalls = self.syscall_counts.iter().collect::>(); + pub print_report: bool, - // Sort syscalls by syscall name - sorted_syscalls.sort_by_key(|&(syscall, _)| format!("{:?}", syscall)); - for (syscall, count) in sorted_syscalls { - writeln!(f, " {}: {}", syscall, count)?; - } - writeln!(f, "Total Syscall Count: {}", self.total_syscall_count())?; + /// Verifier used to sanity check `verify_sp1_proof` during runtime. + pub subproof_verifier: Arc, - Ok(()) - } + /// Registry of hooks, to be invoked by writing to certain file descriptors. + pub hook_registry: HookRegistry<'a>, } #[derive(Error, Debug)] @@ -146,7 +117,7 @@ pub enum ExecutionError { Unimplemented(), } -impl Runtime { +impl<'a> Runtime<'a> { // Create a new runtime from a program. pub fn new(program: Program, opts: SP1CoreOpts) -> Self { // Create a shared reference to the program. @@ -189,11 +160,24 @@ impl Runtime { syscall_map, emit_events: true, max_syscall_cycles, - report: Default::default(), - should_report: false, + report: ExecutionReport::default(), + print_report: false, + subproof_verifier: Arc::new(DefaultSubproofVerifier::new()), + hook_registry: HookRegistry::default(), } } + /// Invokes the hook corresponding to the given file descriptor `fd` with the data `buf`, + /// returning the resulting data. + pub fn hook(&self, fd: u32, buf: &[u8]) -> Vec> { + self.hook_registry.table[&fd](self.hook_env(), buf) + } + + /// Prepare a `HookEnv` for use by hooks. + pub fn hook_env(&self) -> HookEnv { + HookEnv { runtime: self } + } + /// Recover runtime state from a program and existing execution state. 
pub fn recover(program: Program, state: ExecutionState, opts: SP1CoreOpts) -> Self { let mut runtime = Self::new(program, opts); @@ -445,6 +429,8 @@ impl Runtime { memory_store_value: Option, record: MemoryAccessRecord, exit_code: u32, + lookup_id: usize, + syscall_lookup_id: usize, ) { let cpu_event = CpuEvent { shard, @@ -462,14 +448,25 @@ impl Runtime { memory: memory_store_value, memory_record: record.memory, exit_code, + alu_lookup_id: lookup_id, + syscall_lookup_id, + memory_add_lookup_id: create_alu_lookup_id(), + memory_sub_lookup_id: create_alu_lookup_id(), + branch_lt_lookup_id: create_alu_lookup_id(), + branch_gt_lookup_id: create_alu_lookup_id(), + branch_add_lookup_id: create_alu_lookup_id(), + jump_jal_lookup_id: create_alu_lookup_id(), + jump_jalr_lookup_id: create_alu_lookup_id(), + auipc_lookup_id: create_alu_lookup_id(), }; self.record.cpu_events.push(cpu_event); } /// Emit an ALU event. - fn emit_alu(&mut self, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) { + fn emit_alu(&mut self, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32, lookup_id: usize) { let event = AluEvent { + lookup_id, shard: self.shard(), clk, channel: self.channel(), @@ -477,6 +474,7 @@ impl Runtime { a, b, c, + sub_lookups: create_alu_lookups(), }; match opcode { Opcode::ADD => { @@ -530,10 +528,18 @@ impl Runtime { } /// Set the destination register with the result and emit an ALU event. - fn alu_rw(&mut self, instruction: Instruction, rd: Register, a: u32, b: u32, c: u32) { + fn alu_rw( + &mut self, + instruction: Instruction, + rd: Register, + a: u32, + b: u32, + c: u32, + lookup_id: usize, + ) { self.rw(rd, a); if self.emit_events { - self.emit_alu(self.state.clk, instruction.opcode, a, b, c); + self.emit_alu(self.state.clk, instruction.opcode, a, b, c, lookup_id); } } @@ -586,9 +592,12 @@ impl Runtime { let mut memory_store_value: Option = None; self.memory_accesses = MemoryAccessRecord::default(); - if self.should_report && !self.unconstrained { + let lookup_id = create_alu_lookup_id(); + let syscall_lookup_id = create_alu_lookup_id(); + + if self.print_report && !self.unconstrained { self.report - .instruction_counts + .opcode_counts .entry(instruction.opcode) .and_modify(|c| *c += 1) .or_insert(1); @@ -599,52 +608,52 @@ impl Runtime { Opcode::ADD => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_add(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SUB => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_sub(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::XOR => { (rd, b, c) = self.alu_rr(instruction); a = b ^ c; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::OR => { (rd, b, c) = self.alu_rr(instruction); a = b | c; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::AND => { (rd, b, c) = self.alu_rr(instruction); a = b & c; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SLL => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_shl(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SRL => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_shr(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SRA => { (rd, b, c) = self.alu_rr(instruction); a = (b as i32).wrapping_shr(c) as u32; - 
self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SLT => { (rd, b, c) = self.alu_rr(instruction); a = if (b as i32) < (c as i32) { 1 } else { 0 }; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SLTU => { (rd, b, c) = self.alu_rr(instruction); a = if b < c { 1 } else { 0 }; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } // Load instructions. @@ -808,7 +817,7 @@ impl Runtime { b = self.rr(Register::X10, MemoryAccessPosition::B); let syscall = SyscallCode::from_u32(syscall_id); - if self.should_report && !self.unconstrained { + if self.print_report && !self.unconstrained { self.report .syscall_counts .entry(syscall) @@ -818,6 +827,7 @@ impl Runtime { let syscall_impl = self.get_syscall(syscall).cloned(); let mut precompile_rt = SyscallContext::new(self); + precompile_rt.syscall_lookup_id = syscall_lookup_id; let (precompile_next_pc, precompile_cycles, returned_exit_code) = if let Some(syscall_impl) = syscall_impl { // Executing a syscall optionally returns a value to write to the t0 register. @@ -862,22 +872,22 @@ impl Runtime { Opcode::MUL => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_mul(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::MULH => { (rd, b, c) = self.alu_rr(instruction); a = (((b as i32) as i64).wrapping_mul((c as i32) as i64) >> 32) as u32; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::MULHU => { (rd, b, c) = self.alu_rr(instruction); a = ((b as u64).wrapping_mul(c as u64) >> 32) as u32; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::MULHSU => { (rd, b, c) = self.alu_rr(instruction); a = (((b as i32) as i64).wrapping_mul(c as i64) >> 32) as u32; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::DIV => { (rd, b, c) = self.alu_rr(instruction); @@ -886,7 +896,7 @@ impl Runtime { } else { a = (b as i32).wrapping_div(c as i32) as u32; } - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::DIVU => { (rd, b, c) = self.alu_rr(instruction); @@ -895,7 +905,7 @@ impl Runtime { } else { a = b.wrapping_div(c); } - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::REM => { (rd, b, c) = self.alu_rr(instruction); @@ -904,7 +914,7 @@ impl Runtime { } else { a = (b as i32).wrapping_rem(c as i32) as u32; } - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::REMU => { (rd, b, c) = self.alu_rr(instruction); @@ -913,7 +923,7 @@ impl Runtime { } else { a = b.wrapping_rem(c); } - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } // See https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#instruction-aliases @@ -950,6 +960,8 @@ impl Runtime { memory_store_value, self.memory_accesses, exit_code, + lookup_id, + syscall_lookup_id, ); }; Ok(()) @@ -985,6 +997,7 @@ impl Runtime { /// Execute up to `self.shard_batch_size` cycles, returning the events emitted and whether the program ended. 
pub fn execute_record(&mut self) -> Result<(ExecutionRecord, bool), ExecutionError> { self.emit_events = true; + self.print_report = true; let done = self.execute()?; Ok((std::mem::take(&mut self.record), done)) } @@ -992,6 +1005,7 @@ impl Runtime { /// Execute up to `self.shard_batch_size` cycles, returning a copy of the prestate and whether the program ended. pub fn execute_state(&mut self) -> Result<(ExecutionState, bool), ExecutionError> { self.emit_events = false; + self.print_report = false; let state = self.state.clone(); let done = self.execute()?; Ok((state, done)) @@ -1001,7 +1015,7 @@ impl Runtime { self.state.clk = 0; self.state.channel = 0; - tracing::info!("loading memory image"); + tracing::debug!("loading memory image"); for (addr, value) in self.program.memory_image.iter() { self.state.memory.insert( *addr, @@ -1017,20 +1031,18 @@ impl Runtime { self.record .memory_initialize_events .push(MemoryInitializeFinalizeEvent::initialize(0, 0, true)); - - tracing::info!("starting execution"); } pub fn run_untraced(&mut self) -> Result<(), ExecutionError> { self.emit_events = false; - self.should_report = true; + self.print_report = true; while !self.execute()? {} Ok(()) } pub fn run(&mut self) -> Result<(), ExecutionError> { self.emit_events = true; - self.should_report = true; + self.print_report = true; while !self.execute()? {} Ok(()) } @@ -1074,11 +1086,6 @@ impl Runtime { } fn postprocess(&mut self) { - tracing::info!( - "finished execution clk = {} pc = 0x{:x?}", - self.state.global_clk, - self.state.pc - ); // Flush remaining stdout/stderr for (fd, buf) in self.io_buf.iter() { if !buf.is_empty() { @@ -1099,6 +1106,14 @@ impl Runtime { buf.flush().unwrap(); } + // Ensure that all proofs and input bytes were read, otherwise warn the user. + if self.state.proof_stream_ptr != self.state.proof_stream.len() { + panic!("Not all proofs were read. Proving will fail during recursion. Did you pass too many proofs in or forget to call verify_sp1_proof?"); + } + if self.state.input_stream_ptr != self.state.input_stream.len() { + log::warn!("Not all input bytes were read."); + } + // SECTION: Set up all MemoryInitializeFinalizeEvents needed for memory argument. let memory_finalize_events = &mut self.record.memory_finalize_events; @@ -1111,7 +1126,7 @@ impl Runtime { None => &MemoryRecord { value: 0, shard: 0, - timestamp: 0, + timestamp: 1, }, }; memory_finalize_events.push(MemoryInitializeFinalizeEvent::finalize_from_record( @@ -1171,6 +1186,13 @@ pub mod tests { Program::from(PANIC_ELF) } + fn _assert_send() {} + + /// Runtime needs to be Send so we can use it across async calls. 
+ fn _assert_runtime_is_send() { + _assert_send::(); + } + #[test] fn test_simple_program_run() { let program = simple_program(); @@ -1179,58 +1201,6 @@ pub mod tests { assert_eq!(runtime.register(Register::X31), 42); } - #[test] - fn test_ssz_withdrawals_program_run_report() { - let program = ssz_withdrawals_program(); - let mut runtime = Runtime::new(program, SP1CoreOpts::default()); - runtime.run().unwrap(); - assert_eq!(runtime.report, { - use super::Opcode::*; - use super::SyscallCode::*; - super::ExecutionReport { - instruction_counts: [ - (LB, 10723), - (DIVU, 6), - (LW, 237094), - (JALR, 38749), - (XOR, 242242), - (BEQ, 26917), - (AND, 151701), - (SB, 58448), - (MUL, 4036), - (SLTU, 16766), - (ADD, 583439), - (JAL, 5372), - (LBU, 57950), - (SRL, 293010), - (SW, 312781), - (ECALL, 2264), - (BLTU, 43457), - (BGEU, 5917), - (BLT, 1141), - (SUB, 12382), - (BGE, 237), - (MULHU, 1152), - (BNE, 51442), - (AUIPC, 19488), - (OR, 301944), - (SLL, 278698), - ] - .into(), - syscall_counts: [ - (COMMIT_DEFERRED_PROOFS, 8), - (SHA_EXTEND, 1091), - (COMMIT, 8), - (WRITE, 65), - (SHA_COMPRESS, 1091), - (HALT, 1), - ] - .into(), - } - }); - assert_eq!(runtime.report.total_instruction_count(), 2757356); - } - #[test] #[should_panic] fn test_panic() { diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index 67a2464f99..7742c2aa6c 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -17,7 +17,6 @@ use crate::cpu::CpuEvent; use crate::runtime::MemoryInitializeFinalizeEvent; use crate::runtime::MemoryRecordEnum; use crate::stark::MachineRecord; -use crate::syscall::precompiles::blake3::Blake3CompressInnerEvent; use crate::syscall::precompiles::edwards::EdDecompressEvent; use crate::syscall::precompiles::keccak256::KeccakPermuteEvent; use crate::syscall::precompiles::sha256::{ShaCompressEvent, ShaExtendEvent}; @@ -87,8 +86,6 @@ pub struct ExecutionRecord { pub k256_decompress_events: Vec, - pub blake3_compress_inner_events: Vec, - pub bls12381_add_events: Vec, pub bls12381_double_events: Vec, @@ -103,6 +100,8 @@ pub struct ExecutionRecord { /// The public values. pub public_values: PublicValues, + + pub nonce_lookup: HashMap, } pub struct ShardingConfig { @@ -220,10 +219,6 @@ impl MachineRecord for ExecutionRecord { "k256_decompress_events".to_string(), self.k256_decompress_events.len(), ); - stats.insert( - "blake3_compress_inner_events".to_string(), - self.blake3_compress_inner_events.len(), - ); stats.insert( "bls12381_add_events".to_string(), self.bls12381_add_events.len(), @@ -272,8 +267,6 @@ impl MachineRecord for ExecutionRecord { .append(&mut other.bn254_double_events); self.k256_decompress_events .append(&mut other.k256_decompress_events); - self.blake3_compress_inner_events - .append(&mut other.blake3_compress_inner_events); self.bls12381_add_events .append(&mut other.bls12381_add_events); self.bls12381_double_events @@ -356,22 +349,15 @@ impl MachineRecord for ExecutionRecord { } } - // Shard all the other events according to the configuration. - // Shard the ADD events. for (add_chunk, shard) in take(&mut self.add_events) .chunks_mut(config.add_len) .zip(shards.iter_mut()) { shard.add_events.extend_from_slice(add_chunk); - } - - // Shard the MUL events. - for (mul_chunk, shard) in take(&mut self.mul_events) - .chunks_mut(config.mul_len) - .zip(shards.iter_mut()) - { - shard.mul_events.extend_from_slice(mul_chunk); + for (i, event) in add_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the SUB events. 
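Each ALU and syscall event is tagged with a fresh `lookup_id` at execution time; as the events are split into shards here, `nonce_lookup` records the row index each event lands in (note the per-event stride for multi-row precompiles, e.g. `i * 24` for the Keccak permutations below). A consumer holding an event's `lookup_id` can then recover its row nonce, roughly as in this illustrative sketch (`record` stands for whatever owns this `ExecutionRecord`):

    let nonce = record.nonce_lookup.get(&event.lookup_id).copied().unwrap_or(0);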
@@ -380,6 +366,21 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.sub_events.extend_from_slice(sub_chunk); + for (i, event) in sub_chunk.iter().enumerate() { + self.nonce_lookup + .insert(event.lookup_id, shard.add_events.len() as u32 + i as u32); + } + } + + // Shard the MUL events. + for (mul_chunk, shard) in take(&mut self.mul_events) + .chunks_mut(config.mul_len) + .zip(shards.iter_mut()) + { + shard.mul_events.extend_from_slice(mul_chunk); + for (i, event) in mul_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the bitwise events. @@ -388,6 +389,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.bitwise_events.extend_from_slice(bitwise_chunk); + for (i, event) in bitwise_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the shift left events. @@ -396,6 +400,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.shift_left_events.extend_from_slice(shift_left_chunk); + for (i, event) in shift_left_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the shift right events. @@ -406,6 +413,9 @@ impl MachineRecord for ExecutionRecord { shard .shift_right_events .extend_from_slice(shift_right_chunk); + for (i, event) in shift_right_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the divrem events. @@ -414,6 +424,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.divrem_events.extend_from_slice(divrem_chunk); + for (i, event) in divrem_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the LT events. @@ -422,6 +435,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.lt_events.extend_from_slice(lt_chunk); + for (i, event) in lt_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Keccak-256 permute events. @@ -430,6 +446,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.keccak_permute_events.extend_from_slice(keccak_chunk); + for (i, event) in keccak_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, (i * 24) as u32); + } } // secp256k1 curve add events. @@ -440,6 +459,9 @@ impl MachineRecord for ExecutionRecord { shard .secp256k1_add_events .extend_from_slice(secp256k1_add_chunk); + for (i, event) in secp256k1_add_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // secp256k1 curve double events. @@ -450,6 +472,9 @@ impl MachineRecord for ExecutionRecord { shard .secp256k1_double_events .extend_from_slice(secp256k1_double_chunk); + for (i, event) in secp256k1_double_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // bn254 curve add events. @@ -458,6 +483,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.bn254_add_events.extend_from_slice(bn254_add_chunk); + for (i, event) in bn254_add_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // bn254 curve double events. @@ -468,6 +496,9 @@ impl MachineRecord for ExecutionRecord { shard .bn254_double_events .extend_from_slice(bn254_double_chunk); + for (i, event) in bn254_double_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // BLS12-381 curve add events. 
@@ -478,6 +509,9 @@ impl MachineRecord for ExecutionRecord { shard .bls12381_add_events .extend_from_slice(bls12381_add_chunk); + for (i, event) in bls12381_add_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // BLS12-381 curve double events. @@ -488,6 +522,9 @@ impl MachineRecord for ExecutionRecord { shard .bls12381_double_events .extend_from_slice(bls12381_double_chunk); + for (i, event) in bls12381_double_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Put the precompile events in the first shard. @@ -495,38 +532,58 @@ impl MachineRecord for ExecutionRecord { // SHA-256 extend events. first.sha_extend_events = std::mem::take(&mut self.sha_extend_events); + for (i, event) in first.sha_extend_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, (i * 48) as u32); + } // SHA-256 compress events. first.sha_compress_events = std::mem::take(&mut self.sha_compress_events); + for (i, event) in first.sha_compress_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, (i * 80) as u32); + } // Edwards curve add events. first.ed_add_events = std::mem::take(&mut self.ed_add_events); + for (i, event) in first.ed_add_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } // Edwards curve decompress events. first.ed_decompress_events = std::mem::take(&mut self.ed_decompress_events); + for (i, event) in first.ed_decompress_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } // K256 curve decompress events. first.k256_decompress_events = std::mem::take(&mut self.k256_decompress_events); - - // Blake3 compress events . - first.blake3_compress_inner_events = std::mem::take(&mut self.blake3_compress_inner_events); + for (i, event) in first.k256_decompress_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } // Uint256 mul arithmetic events. first.uint256_mul_events = std::mem::take(&mut self.uint256_mul_events); + for (i, event) in first.uint256_mul_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } // Bls12-381 decompress events . first.bls12381_decompress_events = std::mem::take(&mut self.bls12381_decompress_events); + for (i, event) in first.bls12381_decompress_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } - // Put the memory records in the last shard. - let last_shard = shards.last_mut().unwrap(); - - last_shard + first .memory_initialize_events .extend_from_slice(&self.memory_initialize_events); - last_shard + first .memory_finalize_events .extend_from_slice(&self.memory_finalize_events); + // Copy the nonce lookup to all shards. + for shard in shards.iter_mut() { + shard.nonce_lookup.clone_from(&self.nonce_lookup); + } + shards } diff --git a/core/src/runtime/report.rs b/core/src/runtime/report.rs new file mode 100644 index 0000000000..62e85f108b --- /dev/null +++ b/core/src/runtime/report.rs @@ -0,0 +1,113 @@ +use std::collections::hash_map::Entry; +use std::collections::HashMap; +use std::fmt::{Display, Formatter, Result as FmtResult}; +use std::hash::Hash; +use std::ops::{Add, AddAssign}; + +use super::*; + +#[derive(Default, Debug, Clone, PartialEq, Eq)] +pub struct ExecutionReport { + pub opcode_counts: HashMap, + pub syscall_counts: HashMap, +} + +impl ExecutionReport { + /// Compute the total number of instructions run during the execution. 
+ pub fn total_instruction_count(&self) -> u64 { + self.opcode_counts.values().sum() + } + + /// Compute the total number of syscalls made during the execution. + pub fn total_syscall_count(&self) -> u64 { + self.syscall_counts.values().sum() + } + + /// Returns sorted and formatted rows of a table of counts (e.g. `opcode_counts`). + /// + /// The table is sorted first by count (descending) and then by label (ascending). + /// The first column consists of the counts, is right-justified, and is padded precisely + /// enough to fit all the numbers. The second column consists of the labels (e.g. `OpCode`s). + /// The columns are separated by a single space character. + pub fn sorted_table_lines(table: &HashMap) -> Vec + where + K: Ord + Display, + V: Ord + Display, + { + // This function could be optimized here and there, + // for example by pre-allocating all `Vec`s, or by using less memory. + let mut lines = Vec::with_capacity(table.len()); + let mut entries = table.iter().collect::>(); + // Sort table by count (descending), then the name order (ascending). + entries.sort_unstable_by(|a, b| a.1.cmp(b.1).reverse().then_with(|| a.0.cmp(b.0))); + // Convert counts to `String`s to prepare them for printing and to measure their width. + let table_with_string_counts = entries + .into_iter() + .map(|(label, ct)| (label.to_string().to_lowercase(), ct.to_string())) + .collect::>(); + // Calculate width for padding the counts. + let width = table_with_string_counts + .first() + .map(|(_, b)| b.len()) + .unwrap_or_default(); + for (label, count) in table_with_string_counts { + lines.push(format!("{count:>width$} {label}")); + } + lines + } +} + +/// Combines two `HashMap`s together. If a key is in both maps, the values are added together. +fn hashmap_add_assign(lhs: &mut HashMap, rhs: HashMap) +where + K: Eq + Hash, + V: AddAssign, +{ + for (k, v) in rhs.into_iter() { + // Can't use `.and_modify(...).or_insert(...)` because we want to use `v` in both places. + match lhs.entry(k) { + Entry::Occupied(e) => *e.into_mut() += v, + Entry::Vacant(e) => drop(e.insert(v)), + } + } +} + +impl AddAssign for ExecutionReport { + fn add_assign(&mut self, rhs: Self) { + hashmap_add_assign(&mut self.opcode_counts, rhs.opcode_counts); + hashmap_add_assign(&mut self.syscall_counts, rhs.syscall_counts); + } +} + +impl Add for ExecutionReport { + type Output = Self; + + fn add(mut self, rhs: Self) -> Self::Output { + self += rhs; + self + } +} + +impl Display for ExecutionReport { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + writeln!( + f, + "opcode counts ({} total instructions):", + self.total_instruction_count() + )?; + for line in Self::sorted_table_lines(&self.opcode_counts) { + writeln!(f, " {line}")?; + } + + writeln!( + f, + "syscall counts ({} total syscall instructions):", + self.total_syscall_count() + )?; + for line in Self::sorted_table_lines(&self.syscall_counts) { + writeln!(f, " {line}")?; + } + + Ok(()) + } +} diff --git a/core/src/runtime/subproof.rs b/core/src/runtime/subproof.rs new file mode 100644 index 0000000000..701608646d --- /dev/null +++ b/core/src/runtime/subproof.rs @@ -0,0 +1,68 @@ +use std::sync::atomic::AtomicBool; + +use crate::{ + stark::{MachineVerificationError, ShardProof, StarkVerifyingKey}, + utils::BabyBearPoseidon2, +}; + +/// Verifier used in runtime when `sp1_zkvm::precompiles::verify::verify_sp1_proof` is called. This +/// is then used to sanity check that the user passed in the correct proof; the actual constraints +/// happen in the recursion layer. 
+/// +/// This needs to be passed in rather than written directly since the actual implementation relies +/// on crates in recursion that depend on sp1-core. +pub trait SubproofVerifier: Sync + Send { + fn verify_deferred_proof( + &self, + proof: &ShardProof, + vk: &StarkVerifyingKey, + vk_hash: [u32; 8], + committed_value_digest: [u32; 8], + ) -> Result<(), MachineVerificationError>; +} + +/// A dummy verifier which prints a warning on the first proof and does nothing else. +#[derive(Default)] +pub struct DefaultSubproofVerifier { + printed: AtomicBool, +} + +impl DefaultSubproofVerifier { + pub fn new() -> Self { + Self { + printed: AtomicBool::new(false), + } + } +} + +impl SubproofVerifier for DefaultSubproofVerifier { + fn verify_deferred_proof( + &self, + _proof: &ShardProof, + _vk: &StarkVerifyingKey, + _vk_hash: [u32; 8], + _committed_value_digest: [u32; 8], + ) -> Result<(), MachineVerificationError> { + if !self.printed.load(std::sync::atomic::Ordering::SeqCst) { + tracing::info!("Not verifying sub proof during runtime"); + self.printed + .store(true, std::sync::atomic::Ordering::SeqCst); + } + Ok(()) + } +} + +/// A dummy verifier which does nothing. +pub struct NoOpSubproofVerifier; + +impl SubproofVerifier for NoOpSubproofVerifier { + fn verify_deferred_proof( + &self, + _proof: &ShardProof, + _vk: &StarkVerifyingKey, + _vk_hash: [u32; 8], + _committed_value_digest: [u32; 8], + ) -> Result<(), MachineVerificationError> { + Ok(()) + } +} diff --git a/core/src/runtime/syscall.rs b/core/src/runtime/syscall.rs index c320c7e28e..7fd4909534 100644 --- a/core/src/runtime/syscall.rs +++ b/core/src/runtime/syscall.rs @@ -5,7 +5,6 @@ use std::sync::Arc; use strum_macros::EnumIter; use crate::runtime::{Register, Runtime}; -use crate::stark::Blake3CompressInnerChip; use crate::syscall::precompiles::edwards::EdAddAssignChip; use crate::syscall::precompiles::edwards::EdDecompressChip; use crate::syscall::precompiles::keccak256::KeccakPermuteChip; @@ -68,9 +67,6 @@ pub enum SyscallCode { /// Executes the `SECP256K1_DECOMPRESS` precompile. SECP256K1_DECOMPRESS = 0x00_00_01_0C, - /// Executes the `BLAKE3_COMPRESS_INNER` precompile. - BLAKE3_COMPRESS_INNER = 0x00_38_01_0D, - /// Executes the `BN254_ADD` precompile. BN254_ADD = 0x00_01_01_0E, @@ -121,7 +117,6 @@ impl SyscallCode { 0x00_01_01_0A => SyscallCode::SECP256K1_ADD, 0x00_00_01_0B => SyscallCode::SECP256K1_DOUBLE, 0x00_00_01_0C => SyscallCode::SECP256K1_DECOMPRESS, - 0x00_38_01_0D => SyscallCode::BLAKE3_COMPRESS_INNER, 0x00_01_01_0E => SyscallCode::BN254_ADD, 0x00_00_01_0F => SyscallCode::BN254_DOUBLE, 0x00_01_01_1E => SyscallCode::BLS12381_ADD, @@ -172,18 +167,19 @@ pub trait Syscall: Send + Sync { } /// A runtime for syscalls that is protected so that developers cannot arbitrarily modify the runtime. 
-pub struct SyscallContext<'a> { +pub struct SyscallContext<'a, 'b: 'a> { current_shard: u32, pub clk: u32, pub(crate) next_pc: u32, /// This is the exit_code used for the HALT syscall pub(crate) exit_code: u32, - pub(crate) rt: &'a mut Runtime, + pub(crate) rt: &'a mut Runtime<'b>, + pub syscall_lookup_id: usize, } -impl<'a> SyscallContext<'a> { - pub fn new(runtime: &'a mut Runtime) -> Self { +impl<'a, 'b> SyscallContext<'a, 'b> { + pub fn new(runtime: &'a mut Runtime<'b>) -> Self { let current_shard = runtime.shard(); let clk = runtime.state.clk; Self { @@ -192,6 +188,7 @@ impl<'a> SyscallContext<'a> { next_pc: runtime.state.pc.wrapping_add(4), exit_code: 0, rt: runtime, + syscall_lookup_id: 0, } } @@ -304,10 +301,6 @@ pub fn default_syscall_map() -> HashMap> { SyscallCode::BN254_DOUBLE, Arc::new(WeierstrassDoubleAssignChip::::new()), ); - syscall_map.insert( - SyscallCode::BLAKE3_COMPRESS_INNER, - Arc::new(Blake3CompressInnerChip::new()), - ); syscall_map.insert( SyscallCode::BLS12381_ADD, Arc::new(WeierstrassAddAssignChip::::new()), @@ -316,10 +309,6 @@ pub fn default_syscall_map() -> HashMap> { SyscallCode::BLS12381_DOUBLE, Arc::new(WeierstrassDoubleAssignChip::::new()), ); - syscall_map.insert( - SyscallCode::BLAKE3_COMPRESS_INNER, - Arc::new(Blake3CompressInnerChip::new()), - ); syscall_map.insert(SyscallCode::UINT256_MUL, Arc::new(Uint256MulChip::new())); syscall_map.insert( SyscallCode::ENTER_UNCONSTRAINED, @@ -359,10 +348,6 @@ mod tests { fn test_syscalls_in_default_map() { let default_syscall_map = default_syscall_map(); for code in SyscallCode::iter() { - if code == SyscallCode::BLAKE3_COMPRESS_INNER { - // Blake3 is currently disabled. - continue; - } default_syscall_map.get(&code).unwrap(); } } @@ -412,9 +397,6 @@ mod tests { SyscallCode::SECP256K1_DOUBLE => { assert_eq!(code as u32, sp1_zkvm::syscalls::SECP256K1_DOUBLE) } - SyscallCode::BLAKE3_COMPRESS_INNER => { - assert_eq!(code as u32, sp1_zkvm::syscalls::BLAKE3_COMPRESS_INNER) - } SyscallCode::BLS12381_ADD => { assert_eq!(code as u32, sp1_zkvm::syscalls::BLS12381_ADD) } diff --git a/core/src/runtime/utils.rs b/core/src/runtime/utils.rs index 009353abbb..41e132bccc 100644 --- a/core/src/runtime/utils.rs +++ b/core/src/runtime/utils.rs @@ -30,7 +30,7 @@ macro_rules! assert_valid_memory_access { }; } -impl Runtime { +impl<'a> Runtime<'a> { #[inline] pub fn log(&mut self, instruction: &Instruction) { // Write the current program counter to the trace buffer for the cycle tracer. diff --git a/core/src/stark/air.rs b/core/src/stark/air.rs index dc181b18d2..558ccec5ab 100644 --- a/core/src/stark/air.rs +++ b/core/src/stark/air.rs @@ -21,7 +21,6 @@ pub(crate) mod riscv_chips { pub use crate::cpu::CpuChip; pub use crate::memory::MemoryChip; pub use crate::program::ProgramChip; - pub use crate::syscall::precompiles::blake3::Blake3CompressInnerChip; pub use crate::syscall::precompiles::edwards::EdAddAssignChip; pub use crate::syscall::precompiles::edwards::EdDecompressChip; pub use crate::syscall::precompiles::keccak256::KeccakPermuteChip; @@ -88,8 +87,6 @@ pub enum RiscvAir { Secp256k1Double(WeierstrassDoubleAssignChip>), /// A precompile for the Keccak permutation. KeccakP(KeccakPermuteChip), - /// A precompile for the Blake3 compression function. (Disabled by default.) - Blake3Compress(Blake3CompressInnerChip), /// A precompile for addition on the Elliptic curve bn254. Bn254Add(WeierstrassAddAssignChip>), /// A precompile for doubling a point on the Elliptic curve bn254. 
@@ -152,12 +149,12 @@ impl RiscvAir { chips.push(RiscvAir::Uint256Mul(uint256_mul)); let bls12381_decompress = WeierstrassDecompressChip::>::new(); chips.push(RiscvAir::Bls12381Decompress(bls12381_decompress)); + let div_rem = DivRemChip::default(); + chips.push(RiscvAir::DivRem(div_rem)); let add = AddSubChip::default(); chips.push(RiscvAir::Add(add)); let bitwise = BitwiseChip::default(); chips.push(RiscvAir::Bitwise(bitwise)); - let div_rem = DivRemChip::default(); - chips.push(RiscvAir::DivRem(div_rem)); let mul = MulChip::default(); chips.push(RiscvAir::Mul(mul)); let shift_right = ShiftRightChip::default(); diff --git a/core/src/stark/chip.rs b/core/src/stark/chip.rs index 4a31646431..534c444735 100644 --- a/core/src/stark/chip.rs +++ b/core/src/stark/chip.rs @@ -61,12 +61,24 @@ where where A: MachineAir + Air> + Air>, { - // Todo: correct values let mut builder = InteractionBuilder::new(air.preprocessed_width(), air.width()); air.eval(&mut builder); let (sends, receives) = builder.interactions(); - // TODO: enable different numbers of public values. + let nb_byte_sends = sends + .iter() + .filter(|s| s.kind == InteractionKind::Byte) + .count(); + let nb_byte_receives = receives + .iter() + .filter(|r| r.kind == InteractionKind::Byte) + .count(); + tracing::debug!( + "chip {} has {} byte interactions", + air.name(), + nb_byte_sends + nb_byte_receives + ); + let mut max_constraint_degree = get_max_constraint_degree(&air, air.preprocessed_width(), PROOF_MAX_NUM_PVS); @@ -101,7 +113,7 @@ where pub fn generate_permutation_trace>( &self, preprocessed: Option<&RowMajorMatrix>, - main: &mut RowMajorMatrix, + main: &RowMajorMatrix, random_elements: &[EF], ) -> RowMajorMatrix where diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index 672226030b..8891fdc147 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -259,7 +259,7 @@ impl>> StarkMachine { // Display some statistics about the workload. let stats = record.stats(); - log::info!("shard: {:?}", stats); + log::debug!("shard: {:?}", stats); // For each chip, shard the events into segments. record.shard(config) @@ -473,6 +473,8 @@ pub enum MachineVerificationError { DebugInteractionsFailed, EmptyProof, InvalidPublicValues(&'static str), + TooManyShards, + InvalidChipOccurence(String), } impl Debug for MachineVerificationError { @@ -499,6 +501,12 @@ impl Debug for MachineVerificationError { MachineVerificationError::InvalidPublicValues(s) => { write!(f, "Invalid public values: {}", s) } + MachineVerificationError::TooManyShards => { + write!(f, "Too many shards") + } + MachineVerificationError::InvalidChipOccurence(s) => { + write!(f, "Invalid chip occurence: {}", s) + } } } } diff --git a/core/src/stark/permutation.rs b/core/src/stark/permutation.rs index 865397d5d4..ce27b9e9e7 100644 --- a/core/src/stark/permutation.rs +++ b/core/src/stark/permutation.rs @@ -61,7 +61,7 @@ pub const fn permutation_trace_width(num_interactions: usize, batch_size: usize) /// /// The permutation trace has (N+1)*EF::NUM_COLS columns, where N is the number of interactions in /// the chip. 
-pub(crate) fn generate_permutation_trace>( +pub fn generate_permutation_trace>( sends: &[Interaction], receives: &[Interaction], preprocessed: Option<&RowMajorMatrix>, diff --git a/core/src/syscall/precompiles/blake3/compress/air.rs b/core/src/syscall/precompiles/blake3/compress/air.rs deleted file mode 100644 index a5876866e3..0000000000 --- a/core/src/syscall/precompiles/blake3/compress/air.rs +++ /dev/null @@ -1,235 +0,0 @@ -use core::borrow::Borrow; - -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::AbstractField; -use p3_matrix::Matrix; - -use super::columns::{Blake3CompressInnerCols, NUM_BLAKE3_COMPRESS_INNER_COLS}; -use super::g::GOperation; -use super::{ - Blake3CompressInnerChip, G_INDEX, MSG_SCHEDULE, NUM_MSG_WORDS_PER_CALL, - NUM_STATE_WORDS_PER_CALL, OPERATION_COUNT, ROUND_COUNT, -}; -use crate::air::{BaseAirBuilder, SP1AirBuilder, WORD_SIZE}; -use crate::runtime::SyscallCode; - -impl BaseAir for Blake3CompressInnerChip { - fn width(&self) -> usize { - NUM_BLAKE3_COMPRESS_INNER_COLS - } -} - -impl Air for Blake3CompressInnerChip -where - AB: SP1AirBuilder, -{ - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local: &Blake3CompressInnerCols = (*local).borrow(); - let next: &Blake3CompressInnerCols = (*next).borrow(); - - self.constrain_control_flow_flags(builder, local, next); - - self.constrain_memory(builder, local); - - self.constrain_g_operation(builder, local); - - // TODO: constraint ecall_receive column. - // TODO: constraint clk column to increment by 1 within same invocation of syscall. - builder.receive_syscall( - local.shard, - local.channel, - local.clk, - AB::F::from_canonical_u32(SyscallCode::BLAKE3_COMPRESS_INNER.syscall_id()), - local.state_ptr, - local.message_ptr, - local.ecall_receive, - ); - } -} - -impl Blake3CompressInnerChip { - /// Constrains the given index is correct for the given selector. The `selector` is an - /// `n`-dimensional boolean array whose `i`-th element is true if and only if the index is `i`. - fn constrain_index_selector( - &self, - builder: &mut AB, - selector: &[AB::Var], - index: AB::Var, - is_real: AB::Var, - ) { - let mut acc: AB::Expr = AB::F::zero().into(); - for i in 0..selector.len() { - acc += selector[i].into(); - builder.assert_bool(selector[i]) - } - builder - .when(is_real) - .assert_eq(acc, AB::F::from_canonical_usize(1)); - for i in 0..selector.len() { - builder - .when(selector[i]) - .assert_eq(index, AB::F::from_canonical_usize(i)); - } - } - - /// Constrains the control flow flags such as the operation index and the round index. - fn constrain_control_flow_flags( - &self, - builder: &mut AB, - local: &Blake3CompressInnerCols, - next: &Blake3CompressInnerCols, - ) { - // If this is the i-th operation, then the next row should be the (i+1)-th operation. - for i in 0..OPERATION_COUNT { - builder.when_transition().when(next.is_real).assert_eq( - local.is_operation_index_n[i], - next.is_operation_index_n[(i + 1) % OPERATION_COUNT], - ); - } - - // If this is the last operation, the round index should be incremented. Otherwise, the - // round index should remain the same. 
- for i in 0..OPERATION_COUNT { - if i + 1 < OPERATION_COUNT { - builder - .when_transition() - .when(local.is_operation_index_n[i]) - .assert_eq(local.round_index, next.round_index); - } else { - builder - .when_transition() - .when(local.is_operation_index_n[i]) - .when_not(local.is_round_index_n[ROUND_COUNT - 1]) - .assert_eq( - local.round_index + AB::F::from_canonical_u16(1), - next.round_index, - ); - - builder - .when_transition() - .when(local.is_operation_index_n[i]) - .when(local.is_round_index_n[ROUND_COUNT - 1]) - .assert_zero(next.round_index); - } - } - } - - /// Constrain the memory access for the state and the message. - fn constrain_memory( - &self, - builder: &mut AB, - local: &Blake3CompressInnerCols, - ) { - // Calculate the 4 indices to read from the state. This corresponds to a, b, c, and d. - for i in 0..NUM_STATE_WORDS_PER_CALL { - let index_to_read = { - self.constrain_index_selector( - builder, - &local.is_operation_index_n, - local.operation_index, - local.is_real, - ); - - let mut acc = AB::Expr::from_canonical_usize(0); - for operation in 0..OPERATION_COUNT { - acc += AB::Expr::from_canonical_usize(G_INDEX[operation][i]) - * local.is_operation_index_n[operation]; - } - acc - }; - builder.assert_eq(local.state_index[i], index_to_read); - } - - // Read & write the state. - for i in 0..NUM_STATE_WORDS_PER_CALL { - builder.eval_memory_access( - local.shard, - local.channel, - local.clk, - local.state_ptr + local.state_index[i] * AB::F::from_canonical_usize(WORD_SIZE), - &local.state_reads_writes[i], - local.is_real, - ); - } - - // Calculate the indices to read from the message. - for i in 0..NUM_MSG_WORDS_PER_CALL { - let index_to_read = { - self.constrain_index_selector( - builder, - &local.is_round_index_n, - local.round_index, - local.is_real, - ); - - let mut acc = AB::Expr::from_canonical_usize(0); - - for round in 0..ROUND_COUNT { - for operation in 0..OPERATION_COUNT { - acc += - AB::Expr::from_canonical_usize(MSG_SCHEDULE[round][2 * operation + i]) - * local.is_operation_index_n[operation] - * local.is_round_index_n[round]; - } - } - acc - }; - builder.assert_eq(local.msg_schedule[i], index_to_read); - } - - // Read the message. - for i in 0..NUM_MSG_WORDS_PER_CALL { - builder.eval_memory_access( - local.shard, - local.channel, - local.clk, - local.message_ptr + local.msg_schedule[i] * AB::F::from_canonical_usize(WORD_SIZE), - &local.message_reads[i], - local.is_real, - ); - } - } - - /// Constrains the input and the output of the `g` operation. - fn constrain_g_operation( - &self, - builder: &mut AB, - local: &Blake3CompressInnerCols, - ) { - builder.assert_bool(local.is_real); - - // Call g and write the result to the state. - { - let input = [ - local.state_reads_writes[0].prev_value, - local.state_reads_writes[1].prev_value, - local.state_reads_writes[2].prev_value, - local.state_reads_writes[3].prev_value, - local.message_reads[0].access.value, - local.message_reads[1].access.value, - ]; - - // Call the g function. - GOperation::::eval( - builder, - input, - local.g, - local.shard, - local.channel, - local.is_real, - ); - - // Finally, the results of the g function should be written to the memory. 
- for i in 0..NUM_STATE_WORDS_PER_CALL { - for j in 0..WORD_SIZE { - builder.when(local.is_real).assert_eq( - local.state_reads_writes[i].access.value[j], - local.g.result[i][j], - ); - } - } - } - } -} diff --git a/core/src/syscall/precompiles/blake3/compress/columns.rs b/core/src/syscall/precompiles/blake3/compress/columns.rs deleted file mode 100644 index bf7bbe4e1e..0000000000 --- a/core/src/syscall/precompiles/blake3/compress/columns.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::mem::size_of; - -use sp1_derive::AlignedBorrow; - -use crate::memory::MemoryReadCols; -use crate::memory::MemoryReadWriteCols; - -use super::g::GOperation; -use super::NUM_MSG_WORDS_PER_CALL; -use super::NUM_STATE_WORDS_PER_CALL; -use super::OPERATION_COUNT; -use super::ROUND_COUNT; - -pub const NUM_BLAKE3_COMPRESS_INNER_COLS: usize = size_of::>(); - -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] -#[repr(C)] -pub struct Blake3CompressInnerCols { - pub shard: T, - pub channel: T, - pub clk: T, - pub ecall_receive: T, - - /// The pointer to the state. - pub state_ptr: T, - - /// The pointer to the message. - pub message_ptr: T, - - /// Reads and writes a part of the state. - pub state_reads_writes: [MemoryReadWriteCols; NUM_STATE_WORDS_PER_CALL], - - /// Reads a part of the message. - pub message_reads: [MemoryReadCols; NUM_MSG_WORDS_PER_CALL], - - /// Indicates which call of `g` is being performed. - pub operation_index: T, - pub is_operation_index_n: [T; OPERATION_COUNT], - - /// Indicates which call of `round` is being performed. - pub round_index: T, - pub is_round_index_n: [T; ROUND_COUNT], - - /// The indices to pass to `g`. - pub state_index: [T; NUM_STATE_WORDS_PER_CALL], - - /// The two values from `MSG_SCHEDULE` to pass to `g`. - pub msg_schedule: [T; NUM_MSG_WORDS_PER_CALL], - - /// The `g` operation to perform. - pub g: GOperation, - - /// Indicates if the current call is real or not. - pub is_real: T, -} diff --git a/core/src/syscall/precompiles/blake3/compress/execute.rs b/core/src/syscall/precompiles/blake3/compress/execute.rs deleted file mode 100644 index 35298b0415..0000000000 --- a/core/src/syscall/precompiles/blake3/compress/execute.rs +++ /dev/null @@ -1,76 +0,0 @@ -use crate::runtime::Syscall; -use crate::runtime::{MemoryReadRecord, MemoryWriteRecord}; -use crate::syscall::precompiles::blake3::{ - g_func, Blake3CompressInnerChip, Blake3CompressInnerEvent, G_INDEX, MSG_SCHEDULE, - NUM_MSG_WORDS_PER_CALL, NUM_STATE_WORDS_PER_CALL, OPERATION_COUNT, ROUND_COUNT, -}; -use crate::syscall::precompiles::SyscallContext; - -impl Syscall for Blake3CompressInnerChip { - fn num_extra_cycles(&self) -> u32 { - (ROUND_COUNT * OPERATION_COUNT) as u32 - } - - fn execute(&self, rt: &mut SyscallContext, arg1: u32, arg2: u32) -> Option { - let state_ptr = arg1; - let message_ptr = arg2; - - let start_clk = rt.clk; - let mut message_reads = - [[[MemoryReadRecord::default(); NUM_MSG_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT]; - let mut state_writes = [[[MemoryWriteRecord::default(); NUM_STATE_WORDS_PER_CALL]; - OPERATION_COUNT]; ROUND_COUNT]; - - for round in 0..ROUND_COUNT { - for operation in 0..OPERATION_COUNT { - let state_index = G_INDEX[operation]; - let message_index: [usize; NUM_MSG_WORDS_PER_CALL] = [ - MSG_SCHEDULE[round][2 * operation], - MSG_SCHEDULE[round][2 * operation + 1], - ]; - - let mut input = vec![]; - // Read the input to g. 
- { - for index in state_index.iter() { - input.push(rt.word_unsafe(state_ptr + (*index as u32) * 4)); - } - for i in 0..NUM_MSG_WORDS_PER_CALL { - let (record, value) = rt.mr(message_ptr + (message_index[i] as u32) * 4); - message_reads[round][operation][i] = record; - input.push(value); - } - } - - // Call g. - let results = g_func(input.try_into().unwrap()); - - // Write the state. - for i in 0..NUM_STATE_WORDS_PER_CALL { - state_writes[round][operation][i] = - rt.mw(state_ptr + (state_index[i] as u32) * 4, results[i]); - } - - // Increment the clock for the next call of g. - rt.clk += 1; - } - } - - let shard = rt.current_shard(); - let channel = rt.current_channel(); - - rt.record_mut() - .blake3_compress_inner_events - .push(Blake3CompressInnerEvent { - shard, - channel, - clk: start_clk, - state_ptr, - message_reads, - state_writes, - message_ptr, - }); - - None - } -} diff --git a/core/src/syscall/precompiles/blake3/compress/g.rs b/core/src/syscall/precompiles/blake3/compress/g.rs deleted file mode 100644 index 06e8c30348..0000000000 --- a/core/src/syscall/precompiles/blake3/compress/g.rs +++ /dev/null @@ -1,277 +0,0 @@ -use p3_field::Field; -use sp1_derive::AlignedBorrow; - -use crate::air::SP1AirBuilder; -use crate::air::Word; -use crate::air::WORD_SIZE; -use crate::operations::AddOperation; -use crate::operations::FixedRotateRightOperation; -use crate::operations::XorOperation; -use crate::runtime::ExecutionRecord; - -use super::g_func; -/// A set of columns needed to compute the `g` of the input state. -/// ``` ignore -/// fn g(state: &mut BlockWords, a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) { -/// state[a] = state[a].wrapping_add(state[b]).wrapping_add(x); -/// state[d] = (state[d] ^ state[a]).rotate_right(16); -/// state[c] = state[c].wrapping_add(state[d]); -/// state[b] = (state[b] ^ state[c]).rotate_right(12); -/// state[a] = state[a].wrapping_add(state[b]).wrapping_add(y); -/// state[d] = (state[d] ^ state[a]).rotate_right(8); -/// state[c] = state[c].wrapping_add(state[d]); -/// state[b] = (state[b] ^ state[c]).rotate_right(7); -/// } -/// ``` -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] -#[repr(C)] -pub struct GOperation { - pub a_plus_b: AddOperation, - pub a_plus_b_plus_x: AddOperation, - pub d_xor_a: XorOperation, - // Rotate right by 16 bits by just shifting bytes. - pub c_plus_d: AddOperation, - pub b_xor_c: XorOperation, - pub b_xor_c_rotate_right_12: FixedRotateRightOperation, - pub a_plus_b_2: AddOperation, - pub a_plus_b_2_add_y: AddOperation, - // Rotate right by 8 bits by just shifting bytes. - pub d_xor_a_2: XorOperation, - pub c_plus_d_2: AddOperation, - pub b_xor_c_2: XorOperation, - pub b_xor_c_2_rotate_right_7: FixedRotateRightOperation, - /// `state[a]`, `state[b]`, `state[c]`, `state[d]` after all the steps. - pub result: [Word; 4], -} - -impl GOperation { - pub fn populate( - &mut self, - record: &mut ExecutionRecord, - shard: u32, - channel: u32, - input: [u32; 6], - ) -> [u32; 4] { - let mut a = input[0]; - let mut b = input[1]; - let mut c = input[2]; - let mut d = input[3]; - let x = input[4]; - let y = input[5]; - - // First 4 steps. - { - // a = a + b + x. - a = self.a_plus_b.populate(record, shard, channel, a, b); - a = self.a_plus_b_plus_x.populate(record, shard, channel, a, x); - - // d = (d ^ a).rotate_right(16). - d = self.d_xor_a.populate(record, shard, channel, d, a); - d = d.rotate_right(16); - - // c = c + d. - c = self.c_plus_d.populate(record, shard, channel, c, d); - - // b = (b ^ c).rotate_right(12). 
- b = self.b_xor_c.populate(record, shard, channel, b, c); - b = self - .b_xor_c_rotate_right_12 - .populate(record, shard, channel, b, 12); - } - - // Second 4 steps. - { - // a = a + b + y. - a = self.a_plus_b_2.populate(record, shard, channel, a, b); - a = self.a_plus_b_2_add_y.populate(record, shard, channel, a, y); - - // d = (d ^ a).rotate_right(8). - d = self.d_xor_a_2.populate(record, shard, channel, d, a); - d = d.rotate_right(8); - - // c = c + d. - c = self.c_plus_d_2.populate(record, shard, channel, c, d); - - // b = (b ^ c).rotate_right(7). - b = self.b_xor_c_2.populate(record, shard, channel, b, c); - b = self - .b_xor_c_2_rotate_right_7 - .populate(record, shard, channel, b, 7); - } - - let result = [a, b, c, d]; - assert_eq!(result, g_func(input)); - self.result = result.map(Word::from); - result - } - - pub fn eval( - builder: &mut AB, - input: [Word; 6], - cols: GOperation, - shard: AB::Var, - channel: impl Into + Clone, - is_real: AB::Var, - ) { - builder.assert_bool(is_real); - let mut a = input[0]; - let mut b = input[1]; - let mut c = input[2]; - let mut d = input[3]; - let x = input[4]; - let y = input[5]; - - // First 4 steps. - { - // a = a + b + x. - AddOperation::::eval( - builder, - a, - b, - cols.a_plus_b, - shard, - channel.clone(), - is_real.into(), - ); - a = cols.a_plus_b.value; - AddOperation::::eval( - builder, - a, - x, - cols.a_plus_b_plus_x, - shard, - channel.clone(), - is_real.into(), - ); - a = cols.a_plus_b_plus_x.value; - - // d = (d ^ a).rotate_right(16). - XorOperation::::eval( - builder, - d, - a, - cols.d_xor_a, - shard, - channel.clone(), - is_real, - ); - d = cols.d_xor_a.value; - // Rotate right by 16 bits. - d = Word([d[2], d[3], d[0], d[1]]); - - // c = c + d. - AddOperation::::eval( - builder, - c, - d, - cols.c_plus_d, - shard, - channel.clone(), - is_real.into(), - ); - c = cols.c_plus_d.value; - - // b = (b ^ c).rotate_right(12). - XorOperation::::eval( - builder, - b, - c, - cols.b_xor_c, - shard, - channel.clone(), - is_real, - ); - b = cols.b_xor_c.value; - FixedRotateRightOperation::::eval( - builder, - b, - 12, - cols.b_xor_c_rotate_right_12, - shard, - channel.clone(), - is_real, - ); - b = cols.b_xor_c_rotate_right_12.value; - } - - // Second 4 steps. - { - // a = a + b + y. - AddOperation::::eval( - builder, - a, - b, - cols.a_plus_b_2, - shard, - channel.clone(), - is_real.into(), - ); - a = cols.a_plus_b_2.value; - AddOperation::::eval( - builder, - a, - y, - cols.a_plus_b_2_add_y, - shard, - channel.clone(), - is_real.into(), - ); - a = cols.a_plus_b_2_add_y.value; - - // d = (d ^ a).rotate_right(8). - XorOperation::::eval( - builder, - d, - a, - cols.d_xor_a_2, - shard, - channel.clone(), - is_real, - ); - d = cols.d_xor_a_2.value; - // Rotate right by 8 bits. - d = Word([d[1], d[2], d[3], d[0]]); - - // c = c + d. - AddOperation::::eval( - builder, - c, - d, - cols.c_plus_d_2, - shard, - channel.clone(), - is_real.into(), - ); - c = cols.c_plus_d_2.value; - - // b = (b ^ c).rotate_right(7). 
- XorOperation::::eval( - builder, - b, - c, - cols.b_xor_c_2, - shard, - channel.clone(), - is_real, - ); - b = cols.b_xor_c_2.value; - FixedRotateRightOperation::::eval( - builder, - b, - 7, - cols.b_xor_c_2_rotate_right_7, - shard, - channel.clone(), - is_real, - ); - b = cols.b_xor_c_2_rotate_right_7.value; - } - - let results = [a, b, c, d]; - for i in 0..4 { - for j in 0..WORD_SIZE { - builder.assert_eq(cols.result[i][j], results[i][j]); - } - } - } -} diff --git a/core/src/syscall/precompiles/blake3/compress/mod.rs b/core/src/syscall/precompiles/blake3/compress/mod.rs deleted file mode 100644 index a89b9bcc34..0000000000 --- a/core/src/syscall/precompiles/blake3/compress/mod.rs +++ /dev/null @@ -1,179 +0,0 @@ -//! This module contains the implementation of the `blake3_compress_inner` precompile based on the -//! implementation of the `blake3` hash function in BLAKE3. -//! -//! Pseudo-code. -//! -//! state = [0u32; 16] -//! message = [0u32; 16] -//! -//! for round in 0..7 { -//! for operation in 0..8 { -//! // * Pick 4 indices a, b, c, d for the state, based on the operation index. -//! // * Pick 2 indices x, y for the message, based on both the round and the operation index. -//! // -//! // g takes those 6 values, and updates the 4 state values, at indices a, b, c, d. -//! // -//! // Each call of g becomes one row in the trace. -//! g(&mut state[a], &mut state[b], &mut state[c], &mut state[d], message[x], message[y]); -//! } -//! } -//! -//! Note that this precompile is only the blake3 compress inner function. The Blake3 compress -//! function has a series of 8 XOR operations after the compress inner function. -mod air; -mod columns; -mod execute; -mod g; -mod trace; -use crate::runtime::{MemoryReadRecord, MemoryWriteRecord}; - -use serde::{Deserialize, Serialize}; - -/// The number of `Word`s in the message of the compress inner operation. -pub(crate) const MSG_SIZE: usize = 16; - -/// The number of times we call `round` in the compress inner operation. -pub(crate) const ROUND_COUNT: usize = 7; - -/// The number of times we call `g` in the compress inner operation. -pub(crate) const OPERATION_COUNT: usize = 8; - -/// The number of `Word`s in the state that we pass to `g`. -pub(crate) const NUM_STATE_WORDS_PER_CALL: usize = 4; - -/// The number of `Word`s in the message that we pass to `g`. -pub(crate) const NUM_MSG_WORDS_PER_CALL: usize = 2; - -/// The number of `Word`s in the input of `g`. -pub(crate) const G_INPUT_SIZE: usize = NUM_MSG_WORDS_PER_CALL + NUM_STATE_WORDS_PER_CALL; - -/// 2-dimensional array specifying which message values `g` should access. Values at `(i, 2 * j)` -/// and `(i, 2 * j + 1)` are the indices of the message values that `g` should access in the `j`-th -/// call of the `i`-th round. -pub(crate) const MSG_SCHEDULE: [[usize; MSG_SIZE]; ROUND_COUNT] = [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8], - [3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1], - [10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6], - [12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4], - [9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7], - [11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13], -]; - -/// The `i`-th row of `G_INDEX` is the indices used for the `i`-th call to `g`. 
-pub(crate) const G_INDEX: [[usize; NUM_STATE_WORDS_PER_CALL]; OPERATION_COUNT] = [ - [0, 4, 8, 12], - [1, 5, 9, 13], - [2, 6, 10, 14], - [3, 7, 11, 15], - [0, 5, 10, 15], - [1, 6, 11, 12], - [2, 7, 8, 13], - [3, 4, 9, 14], -]; - -pub(crate) const fn g_func(input: [u32; 6]) -> [u32; 4] { - let mut a = input[0]; - let mut b = input[1]; - let mut c = input[2]; - let mut d = input[3]; - let x = input[4]; - let y = input[5]; - a = a.wrapping_add(b).wrapping_add(x); - d = (d ^ a).rotate_right(16); - c = c.wrapping_add(d); - b = (b ^ c).rotate_right(12); - a = a.wrapping_add(b).wrapping_add(y); - d = (d ^ a).rotate_right(8); - c = c.wrapping_add(d); - b = (b ^ c).rotate_right(7); - [a, b, c, d] -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Blake3CompressInnerEvent { - pub clk: u32, - pub shard: u32, - pub channel: u32, - pub state_ptr: u32, - pub message_ptr: u32, - pub message_reads: [[[MemoryReadRecord; NUM_MSG_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT], - pub state_writes: - [[[MemoryWriteRecord; NUM_STATE_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT], -} - -pub struct Blake3CompressInnerChip {} - -impl Blake3CompressInnerChip { - pub const fn new() -> Self { - Self {} - } -} - -#[cfg(test)] -pub mod compress_tests { - use crate::runtime::Instruction; - use crate::runtime::Opcode; - use crate::runtime::Register; - use crate::runtime::SyscallCode; - use crate::Program; - - use super::MSG_SIZE; - - /// The number of `Word`s in the state of the compress inner operation. - const STATE_SIZE: usize = 16; - - pub fn blake3_compress_internal_program() -> Program { - let state_ptr = 100; - let msg_ptr = 500; - let mut instructions = vec![]; - - for i in 0..STATE_SIZE { - // Store 1000 + i in memory for the i-th word of the state. 1000 + i is an arbitrary - // number that is easy to spot while debugging. - instructions.extend(vec![ - Instruction::new(Opcode::ADD, 29, 0, 1000 + i as u32, false, true), - Instruction::new(Opcode::ADD, 30, 0, state_ptr + i as u32 * 4, false, true), - Instruction::new(Opcode::SW, 29, 30, 0, false, true), - ]); - } - for i in 0..MSG_SIZE { - // Store 2000 + i in memory for the i-th word of the message. 2000 + i is an arbitrary - // number that is easy to spot while debugging. - instructions.extend(vec![ - Instruction::new(Opcode::ADD, 29, 0, 2000 + i as u32, false, true), - Instruction::new(Opcode::ADD, 30, 0, msg_ptr + i as u32 * 4, false, true), - Instruction::new(Opcode::SW, 29, 30, 0, false, true), - ]); - } - instructions.extend(vec![ - Instruction::new( - Opcode::ADD, - 5, - 0, - SyscallCode::BLAKE3_COMPRESS_INNER as u32, - false, - true, - ), - Instruction::new(Opcode::ADD, Register::X10 as u32, 0, state_ptr, false, true), - Instruction::new(Opcode::ADD, Register::X11 as u32, 0, msg_ptr, false, true), - Instruction::new(Opcode::ECALL, 5, 10, 11, false, false), - ]); - Program::new(instructions, 0, 0) - } - - // Tests disabled because syscall is not enabled in default runtime/chip configs. 
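// For reference, a minimal host-side sketch of the compress-inner loop that the
// module docs above describe, built only from `g_func` and the constants defined in
// this file. The helper name is hypothetical and illustrative; the removed chip
// proved the same computation with one call of `g` per trace row.
fn compress_inner_reference(state: &mut [u32; 16], message: &[u32; 16]) {
    for round in 0..ROUND_COUNT {
        for operation in 0..OPERATION_COUNT {
            // Pick the 4 state indices (a, b, c, d) and the 2 message indices (x, y).
            let s = G_INDEX[operation];
            let x = message[MSG_SCHEDULE[round][2 * operation]];
            let y = message[MSG_SCHEDULE[round][2 * operation + 1]];
            // `g_func` returns the updated (a, b, c, d), which are written back.
            let out = g_func([state[s[0]], state[s[1]], state[s[2]], state[s[3]], x, y]);
            for i in 0..NUM_STATE_WORDS_PER_CALL {
                state[s[i]] = out[i];
            }
        }
    }
}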
- // #[test] - // fn prove_babybear() { - // setup_logger(); - // let program = blake3_compress_internal_program(); - // run_test(program).unwrap(); - // } - - // #[test] - // fn test_blake3_compress_inner_elf() { - // setup_logger(); - // let program = Program::from(BLAKE3_COMPRESS_ELF); - // run_test(program).unwrap(); - // } -} diff --git a/core/src/syscall/precompiles/blake3/compress/trace.rs b/core/src/syscall/precompiles/blake3/compress/trace.rs deleted file mode 100644 index 14994cb031..0000000000 --- a/core/src/syscall/precompiles/blake3/compress/trace.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::borrow::BorrowMut; - -use p3_field::PrimeField32; -use p3_matrix::dense::RowMajorMatrix; - -use super::columns::Blake3CompressInnerCols; -use super::{ - G_INDEX, G_INPUT_SIZE, MSG_SCHEDULE, NUM_MSG_WORDS_PER_CALL, NUM_STATE_WORDS_PER_CALL, - OPERATION_COUNT, -}; -use crate::air::MachineAir; -use crate::bytes::event::ByteRecord; -use crate::runtime::ExecutionRecord; -use crate::runtime::MemoryRecordEnum; -use crate::runtime::Program; -use crate::syscall::precompiles::blake3::compress::columns::NUM_BLAKE3_COMPRESS_INNER_COLS; -use crate::syscall::precompiles::blake3::{Blake3CompressInnerChip, ROUND_COUNT}; -use crate::utils::pad_rows; - -impl MachineAir for Blake3CompressInnerChip { - type Record = ExecutionRecord; - type Program = Program; - - fn name(&self) -> String { - "Blake3CompressInner".to_string() - } - - fn generate_trace( - &self, - input: &ExecutionRecord, - output: &mut ExecutionRecord, - ) -> RowMajorMatrix { - let mut rows = Vec::new(); - - let mut new_byte_lookup_events = Vec::new(); - - for i in 0..input.blake3_compress_inner_events.len() { - let event = input.blake3_compress_inner_events[i].clone(); - let shard = event.shard; - let channel = event.channel; - let mut clk = event.clk; - for round in 0..ROUND_COUNT { - for operation in 0..OPERATION_COUNT { - let mut row = [F::zero(); NUM_BLAKE3_COMPRESS_INNER_COLS]; - let cols: &mut Blake3CompressInnerCols = row.as_mut_slice().borrow_mut(); - - // Assign basic values to the columns. - { - cols.shard = F::from_canonical_u32(event.shard); - cols.channel = F::from_canonical_u32(event.channel); - cols.clk = F::from_canonical_u32(clk); - - cols.round_index = F::from_canonical_u32(round as u32); - cols.is_round_index_n[round] = F::one(); - - cols.operation_index = F::from_canonical_u32(operation as u32); - cols.is_operation_index_n[operation] = F::one(); - - for i in 0..NUM_STATE_WORDS_PER_CALL { - cols.state_index[i] = F::from_canonical_usize(G_INDEX[operation][i]); - } - - for i in 0..NUM_MSG_WORDS_PER_CALL { - cols.msg_schedule[i] = - F::from_canonical_usize(MSG_SCHEDULE[round][2 * operation + i]); - } - - if round == 0 && operation == 0 { - cols.ecall_receive = F::one(); - } - } - - // Memory columns. - { - cols.message_ptr = F::from_canonical_u32(event.message_ptr); - for i in 0..NUM_MSG_WORDS_PER_CALL { - cols.message_reads[i].populate( - channel, - event.message_reads[round][operation][i], - &mut new_byte_lookup_events, - ); - } - - cols.state_ptr = F::from_canonical_u32(event.state_ptr); - for i in 0..NUM_STATE_WORDS_PER_CALL { - cols.state_reads_writes[i].populate( - channel, - MemoryRecordEnum::Write(event.state_writes[round][operation][i]), - &mut new_byte_lookup_events, - ); - } - } - - // Apply the `g` operation. 
- { - let input: [u32; G_INPUT_SIZE] = [ - event.state_writes[round][operation][0].prev_value, - event.state_writes[round][operation][1].prev_value, - event.state_writes[round][operation][2].prev_value, - event.state_writes[round][operation][3].prev_value, - event.message_reads[round][operation][0].value, - event.message_reads[round][operation][1].value, - ]; - - cols.g.populate(output, shard, channel, input); - } - - clk += 1; - - cols.is_real = F::one(); - - rows.push(row); - } - } - } - - output.add_byte_lookup_events(new_byte_lookup_events); - - pad_rows(&mut rows, || [F::zero(); NUM_BLAKE3_COMPRESS_INNER_COLS]); - - // Convert the trace to a row major matrix. - RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_BLAKE3_COMPRESS_INNER_COLS, - ) - } - - fn included(&self, shard: &Self::Record) -> bool { - !shard.blake3_compress_inner_events.is_empty() - } -} diff --git a/core/src/syscall/precompiles/blake3/mod.rs b/core/src/syscall/precompiles/blake3/mod.rs deleted file mode 100644 index 8b286ad176..0000000000 --- a/core/src/syscall/precompiles/blake3/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod compress; - -pub use compress::*; diff --git a/core/src/syscall/precompiles/edwards/ed_add.rs b/core/src/syscall/precompiles/edwards/ed_add.rs index f4e15423a7..44d29edb83 100644 --- a/core/src/syscall/precompiles/edwards/ed_add.rs +++ b/core/src/syscall/precompiles/edwards/ed_add.rs @@ -6,6 +6,7 @@ use std::marker::PhantomData; use num::BigUint; use num::Zero; +use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; use p3_field::AbstractField; use p3_field::PrimeField32; @@ -54,6 +55,7 @@ pub struct EdAddAssignCols { pub shard: T, pub channel: T, pub clk: T, + pub nonce: T, pub p_ptr: T, pub q_ptr: T, pub p_access: [MemoryWriteCols; WORDS_CURVE_POINT], @@ -238,10 +240,19 @@ impl MachineAir for Ed }); // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_ED_ADD_COLS, - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut EdAddAssignCols = + trace.values[i * NUM_ED_ADD_COLS..(i + 1) * NUM_ED_ADD_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -261,141 +272,150 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &EdAddAssignCols = (*row).borrow(); + let local = main.row_slice(0); + let local: &EdAddAssignCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &EdAddAssignCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - let x1 = limbs_from_prev_access(&row.p_access[0..8]); - let x2 = limbs_from_prev_access(&row.q_access[0..8]); - let y1 = limbs_from_prev_access(&row.p_access[8..16]); - let y2 = limbs_from_prev_access(&row.q_access[8..16]); + let x1 = limbs_from_prev_access(&local.p_access[0..8]); + let x2 = limbs_from_prev_access(&local.q_access[0..8]); + let y1 = limbs_from_prev_access(&local.p_access[8..16]); + let y2 = limbs_from_prev_access(&local.q_access[8..16]); // x3_numerator = x1 * y2 + x2 * y1. - row.x3_numerator.eval( + local.x3_numerator.eval( builder, &[x1, x2], &[y2, y1], - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // y3_numerator = y1 * y2 + x1 * x2. 
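        // (Full twisted Edwards affine addition, as encoded by the constraints in
        // this function: with f = x1*x2*y1*y2, x3 = (x1*y2 + x2*y1) / (1 + d*f)
        // and y3 = (y1*y2 + x1*x2) / (1 - d*f).)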
- row.y3_numerator.eval( + local.y3_numerator.eval( builder, &[y1, x1], &[y2, x2], - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // f = x1 * x2 * y1 * y2. - row.x1_mul_y1.eval( + local.x1_mul_y1.eval( builder, &x1, &y1, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.x2_mul_y2.eval( + local.x2_mul_y2.eval( builder, &x2, &y2, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - let x1_mul_y1 = row.x1_mul_y1.result; - let x2_mul_y2 = row.x2_mul_y2.result; - row.f.eval( + let x1_mul_y1 = local.x1_mul_y1.result; + let x2_mul_y2 = local.x2_mul_y2.result; + local.f.eval( builder, &x1_mul_y1, &x2_mul_y2, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // d * f. - let f = row.f.result; + let f = local.f.result; let d_biguint = E::d_biguint(); let d_const = E::BaseField::to_limbs_field::(&d_biguint); - row.d_mul_f.eval( + local.d_mul_f.eval( builder, &f, &d_const, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - let d_mul_f = row.d_mul_f.result; + let d_mul_f = local.d_mul_f.result; // x3 = x3_numerator / (1 + d * f). - row.x3_ins.eval( + local.x3_ins.eval( builder, - &row.x3_numerator.result, + &local.x3_numerator.result, &d_mul_f, true, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // y3 = y3_numerator / (1 - d * f). - row.y3_ins.eval( + local.y3_ins.eval( builder, - &row.y3_numerator.result, + &local.y3_numerator.result, &d_mul_f, false, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // Constraint self.p_access.value = [self.x3_ins.result, self.y3_ins.result] // This is to ensure that p_access is updated with the new value. 
- let p_access_vec = value_as_limbs(&row.p_access); + let p_access_vec = value_as_limbs(&local.p_access); builder - .when(row.is_real) - .assert_all_eq(row.x3_ins.result, p_access_vec[0..NUM_LIMBS].to_vec()); - builder.when(row.is_real).assert_all_eq( - row.y3_ins.result, + .when(local.is_real) + .assert_all_eq(local.x3_ins.result, p_access_vec[0..NUM_LIMBS].to_vec()); + builder.when(local.is_real).assert_all_eq( + local.y3_ins.result, p_access_vec[NUM_LIMBS..NUM_LIMBS * 2].to_vec(), ); builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk.into(), - row.q_ptr, - &row.q_access, - row.is_real, + local.shard, + local.channel, + local.clk.into(), + local.q_ptr, + &local.q_access, + local.is_real, ); builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk + AB::F::from_canonical_u32(1), - row.p_ptr, - &row.p_access, - row.is_real, + local.shard, + local.channel, + local.clk + AB::F::from_canonical_u32(1), + local.p_ptr, + &local.p_access, + local.is_real, ); builder.receive_syscall( - row.shard, - row.channel, - row.clk, + local.shard, + local.channel, + local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::ED_ADD.syscall_id()), - row.p_ptr, - row.q_ptr, - row.is_real, + local.p_ptr, + local.q_ptr, + local.is_real, ); } } diff --git a/core/src/syscall/precompiles/edwards/ed_decompress.rs b/core/src/syscall/precompiles/edwards/ed_decompress.rs index be62467c00..a0618137ce 100644 --- a/core/src/syscall/precompiles/edwards/ed_decompress.rs +++ b/core/src/syscall/precompiles/edwards/ed_decompress.rs @@ -53,6 +53,7 @@ use super::{WordsFieldElement, WORDS_FIELD_ELEMENT}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EdDecompressEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -78,6 +79,7 @@ pub struct EdDecompressCols { pub shard: T, pub channel: T, pub clk: T, + pub nonce: T, pub ptr: T, pub sign: T, pub x_access: GenericArray, WordsFieldElement>, @@ -104,6 +106,13 @@ impl EdDecompressCols { self.channel = F::from_canonical_u32(event.channel); self.clk = F::from_canonical_u32(event.clk); self.ptr = F::from_canonical_u32(event.ptr); + self.nonce = F::from_canonical_u32( + record + .nonce_lookup + .get(&event.lookup_id) + .copied() + .unwrap_or_default(), + ); self.sign = F::from_bool(event.sign); for i in 0..8 { self.x_access[i].populate( @@ -276,6 +285,7 @@ impl EdDecompressCols { self.shard, self.channel, self.clk, + self.nonce, AB::F::from_canonical_u32(SyscallCode::ED_DECOMPRESS.syscall_id()), self.ptr, self.sign, @@ -326,11 +336,13 @@ impl Syscall for EdDecompressChip { let x_memory_records_vec = rt.mw_slice(slice_ptr, &decompressed_x_words); let x_memory_records: [MemoryWriteRecord; 8] = x_memory_records_vec.try_into().unwrap(); + let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); rt.record_mut() .ed_decompress_events .push(EdDecompressEvent { + lookup_id, shard, channel, clk: start_clk, @@ -390,10 +402,20 @@ impl MachineAir for EdDecompressChip>(), NUM_ED_DECOMPRESS_COLS, - ) + ); + + // Write the nonces to the trace. 
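+        // The nonce is simply the row index: the AIR pins it to zero on the first row
+        // and increments it by one on every transition, and it is passed to
+        // receive_syscall so the syscall interaction can reference a specific row.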
+ for i in 0..trace.height() { + let cols: &mut EdDecompressCols = trace.values + [i * NUM_ED_DECOMPRESS_COLS..(i + 1) * NUM_ED_DECOMPRESS_COLS] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -413,9 +435,18 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &EdDecompressCols = (*row).borrow(); - row.eval::(builder); + let local = main.row_slice(0); + let local: &EdDecompressCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &EdDecompressCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + + local.eval::(builder); } } diff --git a/core/src/syscall/precompiles/keccak256/air.rs b/core/src/syscall/precompiles/keccak256/air.rs index 9e67c12490..1647616798 100644 --- a/core/src/syscall/precompiles/keccak256/air.rs +++ b/core/src/syscall/precompiles/keccak256/air.rs @@ -32,6 +32,12 @@ where let local: &KeccakMemCols = (*local).borrow(); let next: &KeccakMemCols = (*next).borrow(); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + let first_step = local.keccak.step_flags[0]; let final_step = local.keccak.step_flags[NUM_ROUNDS - 1]; let not_final_step = AB::Expr::one() - final_step; @@ -68,6 +74,7 @@ where local.shard, local.channel, local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::KECCAK_PERMUTE.syscall_id()), local.state_addr, AB::Expr::zero(), @@ -79,6 +86,7 @@ where let mut transition_not_final_builder = transition_builder.when(not_final_step); transition_not_final_builder.assert_eq(local.shard, next.shard); transition_not_final_builder.assert_eq(local.clk, next.clk); + transition_not_final_builder.assert_eq(local.channel, next.channel); transition_not_final_builder.assert_eq(local.state_addr, next.state_addr); transition_not_final_builder.assert_eq(local.is_real, next.is_real); @@ -123,6 +131,16 @@ where } } + // Range check all the values in `state_mem` to be bytes. + for i in 0..STATE_NUM_WORDS { + builder.slice_range_check_u8( + &local.state_mem[i].value().0, + local.shard, + local.channel, + local.do_memory_check, + ); + } + let mut sub_builder = SubAirBuilder::::new(builder, 0..NUM_KECCAK_COLS); diff --git a/core/src/syscall/precompiles/keccak256/columns.rs b/core/src/syscall/precompiles/keccak256/columns.rs index a3e2dd3044..ad3aa5f099 100644 --- a/core/src/syscall/precompiles/keccak256/columns.rs +++ b/core/src/syscall/precompiles/keccak256/columns.rs @@ -20,6 +20,7 @@ pub(crate) struct KeccakMemCols { pub shard: T, pub channel: T, pub clk: T, + pub nonce: T, pub state_addr: T, /// Memory columns for the state. diff --git a/core/src/syscall/precompiles/keccak256/execute.rs b/core/src/syscall/precompiles/keccak256/execute.rs index d6c306c45f..eecc747bed 100644 --- a/core/src/syscall/precompiles/keccak256/execute.rs +++ b/core/src/syscall/precompiles/keccak256/execute.rs @@ -99,9 +99,11 @@ impl Syscall for KeccakPermuteChip { // Push the Keccak permute event. 
let shard = rt.current_shard(); let channel = rt.current_channel(); + let lookup_id = rt.syscall_lookup_id; rt.record_mut() .keccak_permute_events .push(KeccakPermuteEvent { + lookup_id, shard, channel, clk: start_clk, diff --git a/core/src/syscall/precompiles/keccak256/mod.rs b/core/src/syscall/precompiles/keccak256/mod.rs index 4110707a83..2b95b8b400 100644 --- a/core/src/syscall/precompiles/keccak256/mod.rs +++ b/core/src/syscall/precompiles/keccak256/mod.rs @@ -15,6 +15,7 @@ const STATE_NUM_WORDS: usize = STATE_SIZE * 2; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KeccakPermuteEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, diff --git a/core/src/syscall/precompiles/keccak256/trace.rs b/core/src/syscall/precompiles/keccak256/trace.rs index 01b07fb743..e4700fe97a 100644 --- a/core/src/syscall/precompiles/keccak256/trace.rs +++ b/core/src/syscall/precompiles/keccak256/trace.rs @@ -83,8 +83,12 @@ impl MachineAir for KeccakPermuteChip { *read_record, &mut new_byte_lookup_events, ); + new_byte_lookup_events.add_u8_range_checks( + shard, + channel, + &read_record.value.to_le_bytes(), + ); } - cols.do_memory_check = F::one(); cols.receive_ecall = F::one(); } @@ -99,8 +103,12 @@ impl MachineAir for KeccakPermuteChip { *write_record, &mut new_byte_lookup_events, ); + new_byte_lookup_events.add_u8_range_checks( + shard, + channel, + &write_record.value.to_le_bytes(), + ); } - cols.do_memory_check = F::one(); } @@ -147,10 +155,19 @@ impl MachineAir for KeccakPermuteChip { } // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_KECCAK_MEM_COLS, - ) + ); + + // Write the nonce to the trace. + for i in 0..trace.height() { + let cols: &mut KeccakMemCols = + trace.values[i * NUM_KECCAK_MEM_COLS..(i + 1) * NUM_KECCAK_MEM_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { diff --git a/core/src/syscall/precompiles/mod.rs b/core/src/syscall/precompiles/mod.rs index 7b107e2d5b..5f1bc43e60 100644 --- a/core/src/syscall/precompiles/mod.rs +++ b/core/src/syscall/precompiles/mod.rs @@ -1,4 +1,3 @@ -pub mod blake3; pub mod edwards; pub mod keccak256; pub mod sha256; @@ -20,6 +19,7 @@ use serde::{Deserialize, Serialize}; /// Elliptic curve add event. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECAddEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -68,6 +68,7 @@ pub fn create_ec_add_event( let p_memory_records = rt.mw_slice(p_ptr, &result_words); ECAddEvent { + lookup_id: rt.syscall_lookup_id, shard: rt.current_shard(), channel: rt.current_channel(), clk: start_clk, @@ -83,6 +84,7 @@ pub fn create_ec_add_event( /// Elliptic curve double event. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECDoubleEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -119,6 +121,7 @@ pub fn create_ec_double_event( let p_memory_records = rt.mw_slice(p_ptr, &result_words); ECDoubleEvent { + lookup_id: rt.syscall_lookup_id, shard: rt.current_shard(), channel: rt.current_channel(), clk: start_clk, @@ -131,6 +134,7 @@ pub fn create_ec_double_event( /// Elliptic curve point decompress event. 
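/// (The `lookup_id` fields added to these events record `rt.syscall_lookup_id` at
/// execution time, so trace generation can later look up the nonce assigned to the
/// corresponding syscall.)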
#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECDecompressEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -176,6 +180,7 @@ pub fn create_ec_decompress_event( let y_memory_records = rt.mw_slice(slice_ptr, &y_words); ECDecompressEvent { + lookup_id: rt.syscall_lookup_id, shard: rt.current_shard(), channel: rt.current_channel(), clk: start_clk, diff --git a/core/src/syscall/precompiles/sha256/compress/air.rs b/core/src/syscall/precompiles/sha256/compress/air.rs index 2f4bd5000a..7a28a456b6 100644 --- a/core/src/syscall/precompiles/sha256/compress/air.rs +++ b/core/src/syscall/precompiles/sha256/compress/air.rs @@ -30,6 +30,12 @@ where let local: &ShaCompressCols = (*local).borrow(); let next: &ShaCompressCols = (*next).borrow(); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + self.eval_control_flow_flags(builder, local, next); self.eval_memory(builder, local); @@ -46,6 +52,7 @@ where local.shard, local.channel, local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::SHA_COMPRESS.syscall_id()), local.w_ptr, local.h_ptr, @@ -71,19 +78,15 @@ impl ShaCompressChip { for i in 0..8 { octet_sum += local.octet[i].into(); } - builder.when(local.is_real).assert_one(octet_sum); + builder.assert_one(octet_sum); // Verify that the first row's octet value is correct. - builder - .when_first_row() - .when(local.is_real) - .assert_one(local.octet[0]); + builder.when_first_row().assert_one(local.octet[0]); // Verify correct transition for octet column. for i in 0..8 { builder .when_transition() - .when(next.is_real) .when(local.octet[i]) .assert_one(next.octet[(i + 1) % 8]) } @@ -98,19 +101,15 @@ impl ShaCompressChip { for i in 0..10 { octet_num_sum += local.octet_num[i].into(); } - builder.when(local.is_real).assert_one(octet_num_sum); + builder.assert_one(octet_num_sum); // The first row should have octet_num[0] = 1 if it's real. - builder - .when_first_row() - .when(local.is_real) - .assert_one(local.octet_num[0]); + builder.when_first_row().assert_one(local.octet_num[0]); // If current row is not last of an octet and next row is real, octet_num should be the same. for i in 0..10 { builder .when_transition() - .when(next.is_real) .when_not(local.octet[7]) .assert_eq(local.octet_num[i], next.octet_num[i]); } @@ -119,7 +118,6 @@ impl ShaCompressChip { for i in 0..10 { builder .when_transition() - .when(next.is_real) .when(local.octet[7]) .assert_eq(local.octet_num[i], next.octet_num[(i + 1) % 10]); } @@ -146,19 +144,26 @@ impl ShaCompressChip { .assert_word_eq(*var, *local.mem.value()); } + // Assert that the is_initialize flag is correct. + builder.assert_eq(local.is_initialize, local.octet_num[0] * local.is_real); + // Assert that the is_compression flag is correct. builder.assert_eq( local.is_compression, - local.octet_num[1] + (local.octet_num[1] + local.octet_num[2] + local.octet_num[3] + local.octet_num[4] + local.octet_num[5] + local.octet_num[6] + local.octet_num[7] - + local.octet_num[8], + + local.octet_num[8]) + * local.is_real, ); + // Assert that the is_finalize flag is correct. 
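+        // (octet_num[9] selects the eight finalize rows of the 80-row compress cycle;
+        // multiplying by is_real keeps the flag zero on padding rows.)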
+ builder.assert_eq(local.is_finalize, local.octet_num[9] * local.is_real); + builder.assert_eq( local.is_last_row.into(), local.octet[7] * local.octet_num[9], @@ -175,6 +180,10 @@ impl ShaCompressChip { .when(local.is_real) .when_not(local.is_last_row) .assert_eq(local.clk, next.clk); + builder + .when_transition() + .when_not(local.is_last_row) + .assert_eq(local.channel, next.channel); builder .when_transition() .when(local.is_real) @@ -186,6 +195,9 @@ impl ShaCompressChip { .when_not(local.is_last_row) .assert_eq(local.h_ptr, next.h_ptr); + // Assert that is_real is a bool. + builder.assert_bool(local.is_real); + // If this row is real and not the last cycle, then next row should also be real. builder .when_transition() @@ -193,6 +205,12 @@ impl ShaCompressChip { .when_not(local.is_last_row) .assert_one(next.is_real); + // Once the is_real flag is changed to false, it should not be changed back. + builder + .when_transition() + .when_not(local.is_real) + .assert_zero(next.is_real); + // Assert that the table ends in nonreal columns. Since each compress ecall is 80 cycles and // the table is padded to a power of 2, the last row of the table should always be padding. builder.when_last_row().assert_zero(local.is_real); @@ -200,15 +218,13 @@ impl ShaCompressChip { /// Constrains that memory address is correct and that memory is correctly written/read. fn eval_memory(&self, builder: &mut AB, local: &ShaCompressCols) { - let is_initialize = local.octet_num[0]; - let is_finalize = local.octet_num[9]; builder.eval_memory_access( local.shard, local.channel, - local.clk + is_finalize, + local.clk + local.is_finalize, local.mem_addr, &local.mem, - is_initialize + local.is_compression + is_finalize, + local.is_initialize + local.is_compression + local.is_finalize, ); // Calculate the current cycle_num. @@ -224,7 +240,7 @@ impl ShaCompressChip { } // Verify correct mem address for initialize phase - builder.when(is_initialize).assert_eq( + builder.when(local.is_initialize).assert_eq( local.mem_addr, local.h_ptr + cycle_step.clone() * AB::Expr::from_canonical_u32(4), ); @@ -239,7 +255,7 @@ impl ShaCompressChip { ); // Verify correct mem address for finalize phase - builder.when(is_finalize).assert_eq( + builder.when(local.is_finalize).assert_eq( local.mem_addr, local.h_ptr + cycle_step.clone() * AB::Expr::from_canonical_u32(4), ); @@ -251,11 +267,11 @@ impl ShaCompressChip { ]; for (i, var) in vars.iter().enumerate() { builder - .when(is_initialize) + .when(local.is_initialize) .when(local.octet[i]) .assert_word_eq(*var, *local.mem.prev_value()); builder - .when(is_initialize) + .when(local.is_initialize) .when(local.octet[i]) .assert_word_eq(*var, *local.mem.value()); } @@ -267,7 +283,7 @@ impl ShaCompressChip { // In the finalize phase, verify that the correct value is written to memory. builder - .when(is_finalize) + .when(local.is_finalize) .assert_word_eq(*local.mem.value(), local.finalize_add.value); } @@ -579,7 +595,6 @@ impl ShaCompressChip { builder: &mut AB, local: &ShaCompressCols, ) { - let is_finalize = local.octet_num[9]; // In the finalize phase, need to execute h[0] + a, h[1] + b, ..., h[7] + h, for each of the // phase's 8 rows. 
// We can get the needed operand (a,b,c,...,h) by doing an inner product between octet and @@ -596,7 +611,7 @@ impl ShaCompressChip { } builder - .when(is_finalize) + .when(local.is_finalize) .assert_word_eq(filtered_operand, local.finalized_operand.map(|x| x.into())); // finalize_add.result = h[i] + finalized_operand @@ -607,7 +622,7 @@ impl ShaCompressChip { local.finalize_add, local.shard, local.channel, - is_finalize.into(), + local.is_finalize.into(), ); // Memory write is constrained in constrain_memory. diff --git a/core/src/syscall/precompiles/sha256/compress/columns.rs b/core/src/syscall/precompiles/sha256/compress/columns.rs index 94a200aedd..0fd7a7fbf4 100644 --- a/core/src/syscall/precompiles/sha256/compress/columns.rs +++ b/core/src/syscall/precompiles/sha256/compress/columns.rs @@ -26,6 +26,7 @@ pub struct ShaCompressCols { /// Inputs. pub shard: T, pub channel: T, + pub nonce: T, pub clk: T, pub w_ptr: T, pub h_ptr: T, @@ -102,7 +103,9 @@ pub struct ShaCompressCols { pub finalized_operand: Word, pub finalize_add: AddOperation, + pub is_initialize: T, pub is_compression: T, + pub is_finalize: T, pub is_last_row: T, pub is_real: T, diff --git a/core/src/syscall/precompiles/sha256/compress/execute.rs b/core/src/syscall/precompiles/sha256/compress/execute.rs index a019abbd4c..5ed33dd2b7 100644 --- a/core/src/syscall/precompiles/sha256/compress/execute.rs +++ b/core/src/syscall/precompiles/sha256/compress/execute.rs @@ -76,9 +76,11 @@ impl Syscall for ShaCompressChip { } // Push the SHA extend event. + let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); rt.record_mut().sha_compress_events.push(ShaCompressEvent { + lookup_id, shard, channel, clk: start_clk, diff --git a/core/src/syscall/precompiles/sha256/compress/mod.rs b/core/src/syscall/precompiles/sha256/compress/mod.rs index fd6c50f0fc..47401a25bc 100644 --- a/core/src/syscall/precompiles/sha256/compress/mod.rs +++ b/core/src/syscall/precompiles/sha256/compress/mod.rs @@ -20,6 +20,7 @@ pub const SHA_COMPRESS_K: [u32; 64] = [ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ShaCompressEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, diff --git a/core/src/syscall/precompiles/sha256/compress/trace.rs b/core/src/syscall/precompiles/sha256/compress/trace.rs index bd0b8f8177..6cd524fbd6 100644 --- a/core/src/syscall/precompiles/sha256/compress/trace.rs +++ b/core/src/syscall/precompiles/sha256/compress/trace.rs @@ -2,6 +2,7 @@ use std::borrow::BorrowMut; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::Matrix; use super::{ columns::{ShaCompressCols, NUM_SHA_COMPRESS_COLS}, @@ -53,6 +54,7 @@ impl MachineAir for ShaCompressChip { cols.octet[j] = F::one(); cols.octet_num[octet_num_idx] = F::one(); + cols.is_initialize = F::one(); cols.mem.populate_read( channel, @@ -207,6 +209,7 @@ impl MachineAir for ShaCompressChip { cols.octet[j] = F::one(); cols.octet_num[octet_num_idx] = F::one(); + cols.is_finalize = F::one(); cols.finalize_add .populate(output, shard, channel, og_h[j], event.h[j]); @@ -249,13 +252,48 @@ impl MachineAir for ShaCompressChip { output.add_byte_lookup_events(new_byte_lookup_events); + let num_real_rows = rows.len(); + pad_rows(&mut rows, || [F::zero(); NUM_SHA_COMPRESS_COLS]); + // Set the octet_num and octect columns for the padded rows. 
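+        // (These selectors are now asserted on every row rather than only when is_real
+        // is set, so padding rows must keep cycling through a valid one-hot pattern and
+        // carry the matching round constant k.)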
+ let mut octet_num = 0; + let mut octet = 0; + for row in rows[num_real_rows..].iter_mut() { + let cols: &mut ShaCompressCols = row.as_mut_slice().borrow_mut(); + cols.octet_num[octet_num] = F::one(); + cols.octet[octet] = F::one(); + + // If in the compression phase, set the k value. + if octet_num != 0 && octet_num != 9 { + let compression_idx = octet_num - 1; + let k_idx = compression_idx * 8 + octet; + cols.k = Word::from(SHA_COMPRESS_K[k_idx]); + } + + octet = (octet + 1) % 8; + if octet == 0 { + octet_num = (octet_num + 1) % 10; + } + + cols.is_last_row = cols.octet[7] * cols.octet_num[9]; + } + // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_SHA_COMPRESS_COLS, - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut ShaCompressCols = trace.values + [i * NUM_SHA_COMPRESS_COLS..(i + 1) * NUM_SHA_COMPRESS_COLS] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { diff --git a/core/src/syscall/precompiles/sha256/extend/air.rs b/core/src/syscall/precompiles/sha256/extend/air.rs index 9da6048048..69f38c3557 100644 --- a/core/src/syscall/precompiles/sha256/extend/air.rs +++ b/core/src/syscall/precompiles/sha256/extend/air.rs @@ -27,6 +27,13 @@ where let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &ShaExtendCols = (*local).borrow(); let next: &ShaExtendCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + let i_start = AB::F::from_canonical_u32(16); let nb_bytes_in_word = AB::F::from_canonical_u32(4); @@ -42,6 +49,10 @@ where .when_transition() .when_not(local.cycle_16_end.result * local.cycle_48[2]) .assert_eq(local.clk, next.clk); + builder + .when_transition() + .when_not(local.cycle_16_end.result * local.cycle_48[2]) + .assert_eq(local.channel, next.channel); builder .when_transition() .when_not(local.cycle_16_end.result * local.cycle_48[2]) @@ -214,22 +225,28 @@ where local.is_real, ); + builder.assert_word_eq(*local.w_i.value(), local.s2.value); + // Receive syscall event in first row of 48-cycle. builder.receive_syscall( local.shard, local.channel, local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::SHA_EXTEND.syscall_id()), local.w_ptr, AB::Expr::zero(), local.cycle_48_start, ); - // If this row is real and not the last cycle, then next row should also be real. + // Assert that is_real is a bool. + builder.assert_bool(local.is_real); + + // Ensure that all rows in a 48 row cycle has the same `is_real` values. builder .when_transition() - .when(local.is_real - local.cycle_48_end) - .assert_one(next.is_real); + .when_not(local.cycle_48_end) + .assert_eq(local.is_real, next.is_real); // Assert that the table ends in nonreal columns. Since each extend ecall is 48 cycles and // the table is padded to a power of 2, the last row of the table should always be padding. diff --git a/core/src/syscall/precompiles/sha256/extend/columns.rs b/core/src/syscall/precompiles/sha256/extend/columns.rs index 5eb99e1f4d..0855b44139 100644 --- a/core/src/syscall/precompiles/sha256/extend/columns.rs +++ b/core/src/syscall/precompiles/sha256/extend/columns.rs @@ -18,6 +18,7 @@ pub struct ShaExtendCols { /// Inputs. 
pub shard: T, pub channel: T, + pub nonce: T, pub clk: T, pub w_ptr: T, @@ -36,8 +37,9 @@ pub struct ShaExtendCols { /// Flags for when in the first, second, or third 16-row cycle. pub cycle_48: [T; 3], - /// Whether the current row is the first of a 48-row cycle. + /// Whether the current row is the first of a 48-row cycle and is real. pub cycle_48_start: T, + /// Whether the current row is the end of a 48-row cycle and is real. pub cycle_48_end: T, /// Inputs to `s0`. diff --git a/core/src/syscall/precompiles/sha256/extend/execute.rs b/core/src/syscall/precompiles/sha256/extend/execute.rs index bd163c26c9..d9b1a70e09 100644 --- a/core/src/syscall/precompiles/sha256/extend/execute.rs +++ b/core/src/syscall/precompiles/sha256/extend/execute.rs @@ -60,9 +60,11 @@ impl Syscall for ShaExtendChip { } // Push the SHA extend event. + let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); rt.record_mut().sha_extend_events.push(ShaExtendEvent { + lookup_id, shard, channel, clk: clk_init, diff --git a/core/src/syscall/precompiles/sha256/extend/flags.rs b/core/src/syscall/precompiles/sha256/extend/flags.rs index 2f97dc92fc..a06f117e3d 100644 --- a/core/src/syscall/precompiles/sha256/extend/flags.rs +++ b/core/src/syscall/precompiles/sha256/extend/flags.rs @@ -7,6 +7,7 @@ use p3_field::PrimeField32; use p3_field::TwoAdicField; use p3_matrix::Matrix; +use crate::air::BaseAirBuilder; use crate::air::SP1AirBuilder; use crate::operations::IsZeroOperation; @@ -70,7 +71,7 @@ impl ShaExtendChip { builder, local.cycle_16 - AB::Expr::from(g), local.cycle_16_start, - local.is_real.into(), + one.clone(), ); // Constrain `cycle_16_end.result` to be `cycle_16 - 1 == 0`. Intuitively g^16 is 1. @@ -78,7 +79,7 @@ impl ShaExtendChip { builder, local.cycle_16 - AB::Expr::one(), local.cycle_16_end, - local.is_real.into(), + one.clone(), ); // Constrain `cycle_48` to be [1, 0, 0] in the first row. @@ -123,10 +124,10 @@ impl ShaExtendChip { .when(local.cycle_16_end.result * local.cycle_48[2]) .assert_eq(next.i, AB::F::from_canonical_u32(16)); - // When it's not the end of a 16-cycle, the next `i` must be the current plus one. + // When it's not the end of a 48-cycle, the next `i` must be the current plus one. builder .when_transition() - .when(one.clone() - local.cycle_16_end.result) + .when_not(local.cycle_16_end.result * local.cycle_48[2]) .assert_eq(local.i + one.clone(), next.i); } } diff --git a/core/src/syscall/precompiles/sha256/extend/mod.rs b/core/src/syscall/precompiles/sha256/extend/mod.rs index 4caff508b9..7868cabd88 100644 --- a/core/src/syscall/precompiles/sha256/extend/mod.rs +++ b/core/src/syscall/precompiles/sha256/extend/mod.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ShaExtendEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, diff --git a/core/src/syscall/precompiles/sha256/extend/trace.rs b/core/src/syscall/precompiles/sha256/extend/trace.rs index 2a976ef0d6..2dcf882260 100644 --- a/core/src/syscall/precompiles/sha256/extend/trace.rs +++ b/core/src/syscall/precompiles/sha256/extend/trace.rs @@ -1,7 +1,7 @@ -use std::borrow::BorrowMut; - use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::Matrix; +use std::borrow::BorrowMut; use crate::{ air::MachineAir, @@ -156,10 +156,19 @@ impl MachineAir for ShaExtendChip { } // Convert the trace to a row major matrix. 
- RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_SHA_EXTEND_COLS, - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut ShaExtendCols = + trace.values[i * NUM_SHA_EXTEND_COLS..(i + 1) * NUM_SHA_EXTEND_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { diff --git a/core/src/syscall/precompiles/uint256/air.rs b/core/src/syscall/precompiles/uint256/air.rs index 498ac78c6c..dd8cce29e4 100644 --- a/core/src/syscall/precompiles/uint256/air.rs +++ b/core/src/syscall/precompiles/uint256/air.rs @@ -17,6 +17,7 @@ use crate::utils::{ use generic_array::GenericArray; use num::Zero; use num::{BigUint, One}; +use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; use p3_field::AbstractField; use p3_field::PrimeField32; @@ -33,6 +34,7 @@ const NUM_COLS: usize = size_of::>(); #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Uint256MulEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -71,6 +73,9 @@ pub struct Uint256MulCols { /// The clock cycle of the syscall. pub clk: T, + /// The none of the operation. + pub nonce: T, + /// The pointer to the first input. pub x_ptr: T, @@ -201,7 +206,17 @@ impl MachineAir for Uint256MulChip { }); // Convert the trace to a row major matrix. - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_COLS) + let mut trace = + RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_COLS); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut Uint256MulCols = + trace.values[i * NUM_COLS..(i + 1) * NUM_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -257,10 +272,12 @@ impl Syscall for Uint256MulChip { // Write the result to x and keep track of the memory records. let x_memory_records = rt.mw_slice(x_ptr, &result); + let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); let clk = rt.clk; rt.record_mut().uint256_mul_events.push(Uint256MulEvent { + lookup_id, shard, channel, clk, @@ -293,6 +310,14 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &Uint256MulCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &Uint256MulCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); // We are computing (x * y) % modulus. The value of x is stored in the "prev_value" of // the x_memory, since we write to it later. 
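// A minimal host-side sketch of the relation the constraints in this function
// encode, using the BigUint type already imported above. The helper name is
// hypothetical and illustrative; limb packing and edge cases follow the chip's
// constraints, not this sketch.
fn uint256_mulmod_reference(x: &BigUint, y: &BigUint, modulus: &BigUint) -> BigUint {
    // result = (x * y) % modulus over 256-bit operands (assumes a nonzero modulus).
    (x * y) % modulus
}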
@@ -368,6 +393,7 @@ where local.shard, local.channel, local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::UINT256_MUL.syscall_id()), local.x_ptr, local.y_ptr, diff --git a/core/src/syscall/precompiles/uint256/mod.rs b/core/src/syscall/precompiles/uint256/mod.rs index 9460a892da..0b35b0056f 100644 --- a/core/src/syscall/precompiles/uint256/mod.rs +++ b/core/src/syscall/precompiles/uint256/mod.rs @@ -13,7 +13,7 @@ mod tests { self, ec::{uint256::U256Field, utils::biguint_from_limbs}, run_test_io, - tests::{UINT256_DIV_ELF, UINT256_MUL_ELF}, + tests::UINT256_MUL_ELF, }, }; @@ -24,13 +24,6 @@ mod tests { run_test_io(program, SP1Stdin::new()).unwrap(); } - #[test] - fn test_uint256_div() { - utils::setup_logger(); - let program = Program::from(UINT256_DIV_ELF); - run_test_io(program, SP1Stdin::new()).unwrap(); - } - #[test] fn test_uint256_modulus() { assert_eq!(biguint_from_limbs(U256Field::MODULUS), U256Field::modulus()); diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs index adbab629f0..2eef29c6c7 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs @@ -52,6 +52,7 @@ pub struct WeierstrassAddAssignCols { pub is_real: T, pub shard: T, pub channel: T, + pub nonce: T, pub clk: T, pub p_ptr: T, pub q_ptr: T, @@ -302,10 +303,21 @@ impl MachineAir }); // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), num_weierstrass_add_cols::(), - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut WeierstrassAddAssignCols = trace.values[i + * num_weierstrass_add_cols::() + ..(i + 1) * num_weierstrass_add_cols::()] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -331,117 +343,125 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &WeierstrassAddAssignCols = (*row).borrow(); + let local = main.row_slice(0); + let local: &WeierstrassAddAssignCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &WeierstrassAddAssignCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); let num_words_field_element = ::Limbs::USIZE / 4; - let p_x = limbs_from_prev_access(&row.p_access[0..num_words_field_element]); - let p_y = limbs_from_prev_access(&row.p_access[num_words_field_element..]); + let p_x = limbs_from_prev_access(&local.p_access[0..num_words_field_element]); + let p_y = limbs_from_prev_access(&local.p_access[num_words_field_element..]); - let q_x = limbs_from_prev_access(&row.q_access[0..num_words_field_element]); - let q_y = limbs_from_prev_access(&row.q_access[num_words_field_element..]); + let q_x = limbs_from_prev_access(&local.q_access[0..num_words_field_element]); + let q_y = limbs_from_prev_access(&local.q_access[num_words_field_element..]); // slope = (q.y - p.y) / (q.x - p.x). 
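        // (Affine chord addition: with s = (q.y - p.y) / (q.x - p.x), the result is
        // x3 = s^2 - p.x - q.x and y3 = s * (p.x - x3) - p.y, which is what the field
        // operations below compute.)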
let slope = { - row.slope_numerator.eval( + local.slope_numerator.eval( builder, &q_y, &p_y, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope_denominator.eval( + local.slope_denominator.eval( builder, &q_x, &p_x, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope.eval( + local.slope.eval( builder, - &row.slope_numerator.result, - &row.slope_denominator.result, + &local.slope_numerator.result, + &local.slope_denominator.result, FieldOperation::Div, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - &row.slope.result + &local.slope.result }; // x = slope * slope - self.x - other.x. let x = { - row.slope_squared.eval( + local.slope_squared.eval( builder, slope, slope, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.p_x_plus_q_x.eval( + local.p_x_plus_q_x.eval( builder, &p_x, &q_x, FieldOperation::Add, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.x3_ins.eval( + local.x3_ins.eval( builder, - &row.slope_squared.result, - &row.p_x_plus_q_x.result, + &local.slope_squared.result, + &local.p_x_plus_q_x.result, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - &row.x3_ins.result + &local.x3_ins.result }; // y = slope * (p.x - x_3n) - q.y. { - row.p_x_minus_x.eval( + local.p_x_minus_x.eval( builder, &p_x, x, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope_times_p_x_minus_x.eval( + local.slope_times_p_x_minus_x.eval( builder, slope, - &row.p_x_minus_x.result, + &local.p_x_minus_x.result, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.y3_ins.eval( + local.y3_ins.eval( builder, - &row.slope_times_p_x_minus_x.result, + &local.slope_times_p_x_minus_x.result, &p_y, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); } @@ -449,29 +469,29 @@ where // ensure that p_access is updated with the new value. for i in 0..E::BaseField::NB_LIMBS { builder - .when(row.is_real) - .assert_eq(row.x3_ins.result[i], row.p_access[i / 4].value()[i % 4]); - builder.when(row.is_real).assert_eq( - row.y3_ins.result[i], - row.p_access[num_words_field_element + i / 4].value()[i % 4], + .when(local.is_real) + .assert_eq(local.x3_ins.result[i], local.p_access[i / 4].value()[i % 4]); + builder.when(local.is_real).assert_eq( + local.y3_ins.result[i], + local.p_access[num_words_field_element + i / 4].value()[i % 4], ); } builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk.into(), - row.q_ptr, - &row.q_access, - row.is_real, + local.shard, + local.channel, + local.clk.into(), + local.q_ptr, + &local.q_access, + local.is_real, ); builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk + AB::F::from_canonical_u32(1), // We read p at +1 since p, q could be the same. - row.p_ptr, - &row.p_access, - row.is_real, + local.shard, + local.channel, + local.clk + AB::F::from_canonical_u32(1), // We read p at +1 since p, q could be the same. + local.p_ptr, + &local.p_access, + local.is_real, ); // Fetch the syscall id for the curve type. 
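The slope, x3, and y3 operations constrained above are the textbook affine addition formulas. A self-contained sketch over a toy prime field (F_17 with the curve y^2 = x^3 + 2x + 2, chosen purely for illustration; the precompile works over the real curve's base field with limb-based field operations):

```rust
// Toy example of affine Weierstrass addition of two distinct points; all values are
// illustrative and small enough for i64 arithmetic.
const P: i64 = 17;

fn modp(x: i64) -> i64 {
    ((x % P) + P) % P
}

fn inv(x: i64) -> i64 {
    // Fermat's little theorem: x^(p-2) mod p.
    let mut acc = 1;
    for _ in 0..(P - 2) {
        acc = modp(acc * x);
    }
    acc
}

/// slope = (q.y - p.y) / (q.x - p.x); x3 = slope^2 - p.x - q.x; y3 = slope * (p.x - x3) - p.y.
fn add(p: (i64, i64), q: (i64, i64)) -> (i64, i64) {
    let slope = modp(modp(q.1 - p.1) * inv(modp(q.0 - p.0)));
    let x3 = modp(slope * slope - p.0 - q.0);
    let y3 = modp(slope * modp(p.0 - x3) - p.1);
    (x3, y3)
}

fn main() {
    let (a, b) = (2, 2);
    let p_pt = (5, 1); // a point on the toy curve
    let q_pt = (6, 3); // another point on the toy curve
    let (x3, y3) = add(p_pt, q_pt);
    // The sum must again satisfy y^2 = x^3 + a*x + b (mod 17).
    assert_eq!(modp(y3 * y3), modp(x3 * x3 * x3 + a * x3 + b));
    println!("p + q = ({x3}, {y3})");
}
```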
@@ -487,13 +507,14 @@ where }; builder.receive_syscall( - row.shard, - row.channel, - row.clk, + local.shard, + local.channel, + local.clk, + local.nonce, syscall_id_felt, - row.p_ptr, - row.q_ptr, - row.is_real, + local.p_ptr, + local.q_ptr, + local.is_real, ); } } diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs index bd38edea8e..62958e86ca 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs @@ -54,6 +54,7 @@ pub struct WeierstrassDecompressCols { pub shard: T, pub channel: T, pub clk: T, + pub nonce: T, pub ptr: T, pub is_odd: T, pub x_access: GenericArray, P::WordsFieldElement>, @@ -222,10 +223,21 @@ impl MachineAir row }); - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), num_weierstrass_decompress_cols::(), - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut WeierstrassDecompressCols = trace.values[i + * num_weierstrass_decompress_cols::() + ..(i + 1) * num_weierstrass_decompress_cols::()] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -250,99 +262,108 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &WeierstrassDecompressCols = (*row).borrow(); + let local = main.row_slice(0); + let local: &WeierstrassDecompressCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &WeierstrassDecompressCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); let num_limbs = ::Limbs::USIZE; let num_words_field_element = num_limbs / 4; - builder.assert_bool(row.is_odd); + builder.assert_bool(local.is_odd); let x: Limbs::Limbs> = - limbs_from_prev_access(&row.x_access); - row.range_x - .eval(builder, &x, row.shard, row.channel, row.is_real); - row.x_2.eval( + limbs_from_prev_access(&local.x_access); + local + .range_x + .eval(builder, &x, local.shard, local.channel, local.is_real); + local.x_2.eval( builder, &x, &x, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.x_3.eval( + local.x_3.eval( builder, - &row.x_2.result, + &local.x_2.result, &x, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); let b = E::b_int(); let b_const = E::BaseField::to_limbs_field::(&b); - row.x_3_plus_b.eval( + local.x_3_plus_b.eval( builder, - &row.x_3.result, + &local.x_3.result, &b_const, FieldOperation::Add, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.neg_y.eval( + local.neg_y.eval( builder, &[AB::Expr::zero()].iter(), - &row.y.multiplication.result, + &local.y.multiplication.result, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // Interpret the lowest bit of Y as whether it is odd or not. 
- let y_is_odd = row.y.lsb; + let y_is_odd = local.y.lsb; - row.y.eval( + local.y.eval( builder, - &row.x_3_plus_b.result, - row.y.lsb, - row.shard, - row.channel, - row.is_real, + &local.x_3_plus_b.result, + local.y.lsb, + local.shard, + local.channel, + local.is_real, ); let y_limbs: Limbs::Limbs> = - limbs_from_access(&row.y_access); + limbs_from_access(&local.y_access); builder - .when(row.is_real) - .when_ne(y_is_odd, AB::Expr::one() - row.is_odd) - .assert_all_eq(row.y.multiplication.result, y_limbs); + .when(local.is_real) + .when_ne(y_is_odd, AB::Expr::one() - local.is_odd) + .assert_all_eq(local.y.multiplication.result, y_limbs); builder - .when(row.is_real) - .when_ne(y_is_odd, row.is_odd) - .assert_all_eq(row.neg_y.result, y_limbs); + .when(local.is_real) + .when_ne(y_is_odd, local.is_odd) + .assert_all_eq(local.neg_y.result, y_limbs); for i in 0..num_words_field_element { builder.eval_memory_access( - row.shard, - row.channel, - row.clk, - row.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4 + num_limbs as u32), - &row.x_access[i], - row.is_real, + local.shard, + local.channel, + local.clk, + local.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4 + num_limbs as u32), + &local.x_access[i], + local.is_real, ); } for i in 0..num_words_field_element { builder.eval_memory_access( - row.shard, - row.channel, - row.clk, - row.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4), - &row.y_access[i], - row.is_real, + local.shard, + local.channel, + local.clk, + local.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4), + &local.y_access[i], + local.is_real, ); } @@ -357,13 +378,14 @@ where }; builder.receive_syscall( - row.shard, - row.channel, - row.clk, + local.shard, + local.channel, + local.clk, + local.nonce, syscall_id, - row.ptr, - row.is_odd, - row.is_real, + local.ptr, + local.is_odd, + local.is_real, ); } } diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs index 50bb0a4332..9221d680f1 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs @@ -53,6 +53,7 @@ pub struct WeierstrassDoubleAssignCols { pub is_real: T, pub shard: T, pub channel: T, + pub nonce: T, pub clk: T, pub p_ptr: T, pub p_access: GenericArray, P::WordsCurvePoint>, @@ -317,10 +318,21 @@ impl MachineAir }); // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), num_weierstrass_double_cols::(), - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut WeierstrassDoubleAssignCols = trace.values[i + * num_weierstrass_double_cols::() + ..(i + 1) * num_weierstrass_double_cols::()] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -346,136 +358,143 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &WeierstrassDoubleAssignCols = (*row).borrow(); + let local = main.row_slice(0); + let local: &WeierstrassDoubleAssignCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &WeierstrassDoubleAssignCols = (*next).borrow(); + + // Constrain the incrementing nonce. 
+ builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); let num_words_field_element = E::BaseField::NB_LIMBS / 4; - let p_x = limbs_from_prev_access(&row.p_access[0..num_words_field_element]); - let p_y = limbs_from_prev_access(&row.p_access[num_words_field_element..]); + let p_x = limbs_from_prev_access(&local.p_access[0..num_words_field_element]); + let p_y = limbs_from_prev_access(&local.p_access[num_words_field_element..]); - // a in the Weierstrass form: y^2 = x^3 + a * x + b. - // TODO: U32 can't be hardcoded here? + // `a` in the Weierstrass form: y^2 = x^3 + a * x + b. let a = E::BaseField::to_limbs_field::(&E::a_int()); // slope = slope_numerator / slope_denominator. let slope = { // slope_numerator = a + (p.x * p.x) * 3. { - row.p_x_squared.eval( + local.p_x_squared.eval( builder, &p_x, &p_x, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.p_x_squared_times_3.eval( + local.p_x_squared_times_3.eval( builder, - &row.p_x_squared.result, + &local.p_x_squared.result, &E::BaseField::to_limbs_field::(&BigUint::from(3u32)), FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope_numerator.eval( + local.slope_numerator.eval( builder, &a, - &row.p_x_squared_times_3.result, + &local.p_x_squared_times_3.result, FieldOperation::Add, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); }; // slope_denominator = 2 * y. - row.slope_denominator.eval( + local.slope_denominator.eval( builder, &E::BaseField::to_limbs_field::(&BigUint::from(2u32)), &p_y, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope.eval( + local.slope.eval( builder, - &row.slope_numerator.result, - &row.slope_denominator.result, + &local.slope_numerator.result, + &local.slope_denominator.result, FieldOperation::Div, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - &row.slope.result + &local.slope.result }; // x = slope * slope - (p.x + p.x). let x = { - row.slope_squared.eval( + local.slope_squared.eval( builder, slope, slope, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.p_x_plus_p_x.eval( + local.p_x_plus_p_x.eval( builder, &p_x, &p_x, FieldOperation::Add, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.x3_ins.eval( + local.x3_ins.eval( builder, - &row.slope_squared.result, - &row.p_x_plus_p_x.result, + &local.slope_squared.result, + &local.p_x_plus_p_x.result, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - &row.x3_ins.result + &local.x3_ins.result }; // y = slope * (p.x - x) - p.y. 
{ - row.p_x_minus_x.eval( + local.p_x_minus_x.eval( builder, &p_x, x, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope_times_p_x_minus_x.eval( + local.slope_times_p_x_minus_x.eval( builder, slope, - &row.p_x_minus_x.result, + &local.p_x_minus_x.result, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.y3_ins.eval( + local.y3_ins.eval( builder, - &row.slope_times_p_x_minus_x.result, + &local.slope_times_p_x_minus_x.result, &p_y, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); } @@ -483,21 +502,21 @@ where // ensure that p_access is updated with the new value. for i in 0..E::BaseField::NB_LIMBS { builder - .when(row.is_real) - .assert_eq(row.x3_ins.result[i], row.p_access[i / 4].value()[i % 4]); - builder.when(row.is_real).assert_eq( - row.y3_ins.result[i], - row.p_access[num_words_field_element + i / 4].value()[i % 4], + .when(local.is_real) + .assert_eq(local.x3_ins.result[i], local.p_access[i / 4].value()[i % 4]); + builder.when(local.is_real).assert_eq( + local.y3_ins.result[i], + local.p_access[num_words_field_element + i / 4].value()[i % 4], ); } builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk.into(), - row.p_ptr, - &row.p_access, - row.is_real, + local.shard, + local.channel, + local.clk.into(), + local.p_ptr, + &local.p_access, + local.is_real, ); // Fetch the syscall id for the curve type. @@ -513,13 +532,14 @@ where }; builder.receive_syscall( - row.shard, - row.channel, - row.clk, + local.shard, + local.channel, + local.clk, + local.nonce, syscall_id_felt, - row.p_ptr, + local.p_ptr, AB::Expr::zero(), - row.is_real, + local.is_real, ); } } diff --git a/core/src/syscall/verify.rs b/core/src/syscall/verify.rs index e40639aeba..1909da9941 100644 --- a/core/src/syscall/verify.rs +++ b/core/src/syscall/verify.rs @@ -1,8 +1,6 @@ -use crate::{ - runtime::{Syscall, SyscallContext}, - stark::{RiscvAir, StarkGenericConfig}, - utils::BabyBearPoseidon2Inner, -}; +use core::panic; + +use crate::runtime::{Syscall, SyscallContext}; /// Verifies an SP1 recursive verifier proof. Note that this syscall only verifies the proof during /// runtime. The actual constraint-level verification is deferred to the recursive layer, where @@ -16,7 +14,6 @@ impl SyscallVerifySP1Proof { } impl Syscall for SyscallVerifySP1Proof { - #[allow(unused_variables, unused_mut)] fn execute(&self, ctx: &mut SyscallContext, vkey_ptr: u32, pv_digest_ptr: u32) -> Option { let rt = &mut ctx.rt; @@ -33,22 +30,26 @@ impl Syscall for SyscallVerifySP1Proof { .map(|i| rt.word(pv_digest_ptr + i * 4)) .collect::>(); - let (proof, proof_vk) = &rt.state.proof_stream[rt.state.proof_stream_ptr]; + let proof_index = rt.state.proof_stream_ptr; + if proof_index >= rt.state.proof_stream.len() { + panic!("Not enough proofs were written to the runtime."); + } + let (proof, proof_vk) = &rt.state.proof_stream[proof_index].clone(); rt.state.proof_stream_ptr += 1; - let config = BabyBearPoseidon2Inner::new(); - let mut challenger = config.challenger(); - // TODO: need to use RecursionAir here - let machine = RiscvAir::machine(config); - - // TODO: Need to import PublicValues from recursion. - // Assert the commit in vkey from runtime inputs matches the one from syscall. - // Assert that the public values digest from runtime inputs matches the one from syscall. 
- - // TODO: Verify proof - // machine - // .verify(proof_vk, proof, &mut challenger) - // .expect("proof verification failed"); + let vkey_bytes: [u32; 8] = vkey.try_into().unwrap(); + let pv_digest_bytes: [u32; 8] = pv_digest.try_into().unwrap(); + + ctx.rt + .subproof_verifier + .verify_deferred_proof(proof, proof_vk, vkey_bytes, pv_digest_bytes) + .unwrap_or_else(|e| { + panic!( + "Failed to verify proof {proof_index} with digest {}: {}", + hex::encode(bytemuck::cast_slice(&pv_digest_bytes)), + e + ) + }); None } diff --git a/core/src/syscall/write.rs b/core/src/syscall/write.rs index 2953cd33bd..d8956844d8 100644 --- a/core/src/syscall/write.rs +++ b/core/src/syscall/write.rs @@ -16,66 +16,66 @@ impl Syscall for SyscallWrite { let a2 = Register::X12; let rt = &mut ctx.rt; let fd = arg1; - if fd == 1 || fd == 2 || fd == 3 || fd == 4 { - let write_buf = arg2; - let nbytes = rt.register(a2); - // Read nbytes from memory starting at write_buf. - let bytes = (0..nbytes) - .map(|i| rt.byte(write_buf + i)) - .collect::>(); - let slice = bytes.as_slice(); - if fd == 1 { - let s = core::str::from_utf8(slice).unwrap(); - if s.contains("cycle-tracker-start:") { - let fn_name = s - .split("cycle-tracker-start:") - .last() - .unwrap() - .trim_end() - .trim_start(); - let depth = rt.cycle_tracker.len() as u32; - rt.cycle_tracker - .insert(fn_name.to_string(), (rt.state.global_clk, depth)); - let padding = (0..depth).map(|_| "│ ").collect::(); - log::debug!("{}┌╴{}", padding, fn_name); - } else if s.contains("cycle-tracker-end:") { - let fn_name = s - .split("cycle-tracker-end:") - .last() - .unwrap() - .trim_end() - .trim_start(); - let (start, depth) = rt.cycle_tracker.remove(fn_name).unwrap_or((0, 0)); - // Leftpad by 2 spaces for each depth. - let padding = (0..depth).map(|_| "│ ").collect::(); - log::info!( - "{}└╴{} cycles", - padding, - num_to_comma_separated(rt.state.global_clk - start as u64) - ); - } else { - let flush_s = update_io_buf(ctx, fd, s); - if !flush_s.is_empty() { - flush_s - .into_iter() - .for_each(|line| println!("stdout: {}", line)); - } - } - } else if fd == 2 { - let s = core::str::from_utf8(slice).unwrap(); + let write_buf = arg2; + let nbytes = rt.register(a2); + // Read nbytes from memory starting at write_buf. + let bytes = (0..nbytes) + .map(|i| rt.byte(write_buf + i)) + .collect::>(); + let slice = bytes.as_slice(); + if fd == 1 { + let s = core::str::from_utf8(slice).unwrap(); + if s.contains("cycle-tracker-start:") { + let fn_name = s + .split("cycle-tracker-start:") + .last() + .unwrap() + .trim_end() + .trim_start(); + let depth = rt.cycle_tracker.len() as u32; + rt.cycle_tracker + .insert(fn_name.to_string(), (rt.state.global_clk, depth)); + let padding = (0..depth).map(|_| "│ ").collect::(); + log::debug!("{}┌╴{}", padding, fn_name); + } else if s.contains("cycle-tracker-end:") { + let fn_name = s + .split("cycle-tracker-end:") + .last() + .unwrap() + .trim_end() + .trim_start(); + let (start, depth) = rt.cycle_tracker.remove(fn_name).unwrap_or((0, 0)); + // Leftpad by 2 spaces for each depth. 
+ let padding = (0..depth).map(|_| "│ ").collect::(); + log::info!( + "{}└╴{} cycles", + padding, + num_to_comma_separated(rt.state.global_clk - start as u64) + ); + } else { let flush_s = update_io_buf(ctx, fd, s); if !flush_s.is_empty() { flush_s .into_iter() - .for_each(|line| println!("stderr: {}", line)); + .for_each(|line| println!("stdout: {}", line)); } - } else if fd == 3 { - rt.state.public_values_stream.extend_from_slice(slice); - } else if fd == 4 { - rt.state.input_stream.push(slice.to_vec()); - } else { - unreachable!() } + } else if fd == 2 { + let s = core::str::from_utf8(slice).unwrap(); + let flush_s = update_io_buf(ctx, fd, s); + if !flush_s.is_empty() { + flush_s + .into_iter() + .for_each(|line| println!("stderr: {}", line)); + } + } else if fd == 3 { + rt.state.public_values_stream.extend_from_slice(slice); + } else if fd == 4 { + rt.state.input_stream.push(slice.to_vec()); + } else if let Some(hook) = rt.hook_registry.table.get(&fd) { + rt.state.input_stream.extend(hook(rt.hook_env(), slice)); + } else { + log::warn!("tried to write to unknown file descriptor {fd}"); } None } diff --git a/core/src/utils/mod.rs b/core/src/utils/mod.rs index 664674bee8..e768f999a7 100644 --- a/core/src/utils/mod.rs +++ b/core/src/utils/mod.rs @@ -3,6 +3,7 @@ mod config; pub mod ec; mod logger; mod options; +#[cfg(any(test, feature = "programs"))] mod programs; mod prove; mod tracer; @@ -14,7 +15,7 @@ pub use options::*; pub use prove::*; pub use tracer::*; -#[cfg(test)] +#[cfg(any(test, feature = "programs"))] pub use programs::*; use crate::{memory::MemoryCols, operations::field::params::Limbs}; diff --git a/core/src/utils/options.rs b/core/src/utils/options.rs index d823859d1b..8ef087e3ac 100644 --- a/core/src/utils/options.rs +++ b/core/src/utils/options.rs @@ -1,3 +1,8 @@ +use std::env; + +const DEFAULT_SHARD_SIZE: usize = 1 << 22; +const DEFAULT_SHARD_BATCH_SIZE: usize = 16; + #[derive(Debug, Clone, Copy)] pub struct SP1CoreOpts { pub shard_size: usize, @@ -9,8 +14,14 @@ pub struct SP1CoreOpts { impl Default for SP1CoreOpts { fn default() -> Self { Self { - shard_size: 1 << 22, - shard_batch_size: 16, + shard_size: env::var("SHARD_SIZE").map_or_else( + |_| DEFAULT_SHARD_SIZE, + |s| s.parse::().unwrap_or(DEFAULT_SHARD_SIZE), + ), + shard_batch_size: env::var("SHARD_BATCH_SIZE").map_or_else( + |_| DEFAULT_SHARD_BATCH_SIZE, + |s| s.parse::().unwrap_or(DEFAULT_SHARD_BATCH_SIZE), + ), shard_chunking_multiplier: 1, reconstruct_commitments: true, } @@ -21,6 +32,7 @@ impl SP1CoreOpts { pub fn recursion() -> Self { let mut opts = Self::default(); opts.reconstruct_commitments = false; + opts.shard_size = DEFAULT_SHARD_SIZE; opts } } diff --git a/core/src/utils/programs.rs b/core/src/utils/programs.rs index 58af5a08c4..dada2735df 100644 --- a/core/src/utils/programs.rs +++ b/core/src/utils/programs.rs @@ -1,4 +1,3 @@ -#[cfg(test)] pub mod tests { /// Demos. 
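The `SP1CoreOpts` change above makes shard sizing overridable through the `SHARD_SIZE` and `SHARD_BATCH_SIZE` environment variables, falling back to the hard-coded defaults when the variable is unset or does not parse. A minimal, runnable sketch of that fallback behavior; the constants mirror the diff, while `env_or_default` is a hypothetical helper rather than an SP1 API:

```rust
use std::env;

const DEFAULT_SHARD_SIZE: usize = 1 << 22;
const DEFAULT_SHARD_BATCH_SIZE: usize = 16;

/// Read an override from the environment, using the default when the variable is unset
/// or fails to parse, matching the behavior of the new `SP1CoreOpts::default`.
fn env_or_default(var: &str, default: usize) -> usize {
    env::var(var).map_or(default, |s| s.parse::<usize>().unwrap_or(default))
}

fn main() {
    let shard_size = env_or_default("SHARD_SIZE", DEFAULT_SHARD_SIZE);
    let shard_batch_size = env_or_default("SHARD_BATCH_SIZE", DEFAULT_SHARD_BATCH_SIZE);
    println!("shard_size = {shard_size}, shard_batch_size = {shard_batch_size}");
}
```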
@@ -34,9 +33,6 @@ pub mod tests { pub const ED25519_ELF: &[u8] = include_bytes!("../../../tests/ed25519/elf/riscv32im-succinct-zkvm-elf"); - pub const BLAKE3_COMPRESS_ELF: &[u8] = - include_bytes!("../../../tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf"); - pub const CYCLE_TRACKER_ELF: &[u8] = include_bytes!("../../../tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf"); @@ -97,9 +93,6 @@ pub mod tests { pub const UINT256_MUL_ELF: &[u8] = include_bytes!("../../../tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf"); - pub const UINT256_DIV_ELF: &[u8] = - include_bytes!("../../../tests/uint256-div/elf/riscv32im-succinct-zkvm-elf"); - pub const BLS12381_DECOMPRESS_ELF: &[u8] = include_bytes!("../../../tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf"); diff --git a/core/src/utils/prove.rs b/core/src/utils/prove.rs index 62eaf8b902..8afe5f55b0 100644 --- a/core/src/utils/prove.rs +++ b/core/src/utils/prove.rs @@ -1,6 +1,7 @@ use std::fs::File; use std::io; use std::io::{Seek, Write}; +use std::sync::Arc; use web_time::Instant; pub use baby_bear_blake3::BabyBearBlake3; @@ -14,8 +15,10 @@ use thiserror::Error; use crate::air::MachineAir; use crate::io::{SP1PublicValues, SP1Stdin}; use crate::lookup::InteractionBuilder; -use crate::runtime::ExecutionError; -use crate::runtime::{ExecutionRecord, ShardingConfig}; +use crate::runtime::{ + DefaultSubproofVerifier, ExecutionError, NoOpSubproofVerifier, SubproofVerifier, +}; +use crate::runtime::{ExecutionRecord, ExecutionReport, ShardingConfig}; use crate::stark::DebugConstraintBuilder; use crate::stark::MachineProof; use crate::stark::ProverConstraintFolder; @@ -89,6 +92,24 @@ pub fn prove( config: SC, opts: SP1CoreOpts, ) -> Result<(MachineProof, Vec), SP1CoreProverError> +where + SC::Challenger: Clone, + OpeningProof: Send + Sync, + Com: Send + Sync, + PcsProverData: Send + Sync, + ShardMainData: Serialize + DeserializeOwned, + ::Val: PrimeField32, +{ + prove_with_subproof_verifier::(program, stdin, config, opts, None) +} + +pub fn prove_with_subproof_verifier( + program: Program, + stdin: &SP1Stdin, + config: SC, + opts: SP1CoreOpts, + subproof_verifier: Option>, +) -> Result<(MachineProof, Vec), SP1CoreProverError> where SC::Challenger: Clone, OpeningProof: Send + Sync, @@ -105,6 +126,9 @@ where for proof in stdin.proofs.iter() { runtime.write_proof(proof.0.clone(), proof.1.clone()); } + if let Some(deferred_fn) = subproof_verifier.clone() { + runtime.subproof_verifier = deferred_fn; + } // Setup the machine. let machine = RiscvAir::machine(config); @@ -132,8 +156,8 @@ where let mut checkpoints = Vec::new(); let (public_values_stream, public_values) = loop { // Execute the runtime until we reach a checkpoint. - let (checkpoint, done) = runtime - .execute_state() + let (checkpoint, done) = tracing::info_span!("collect_checkpoints") + .in_scope(|| runtime.execute_state()) .map_err(SP1CoreProverError::ExecutionError)?; // Save the checkpoint to a temp file. 
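The prover loop above executes the program up to a checkpoint, serializes the checkpoint to a temp file, and later rewinds and re-reads it to trace and prove each shard. A simplified sketch of that save/seek/re-read round trip, assuming the `bincode`, `serde` (with the `derive` feature), and `tempfile` crates; `Checkpoint` is a stand-in struct, not the runtime's actual execution state:

```rust
use serde::{Deserialize, Serialize};
use std::io::{BufReader, BufWriter, Seek, SeekFrom, Write};

/// Stand-in for the state the prover checkpoints between execution and proving.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct Checkpoint {
    global_clk: u64,
    pc: u32,
}

fn main() {
    let state = Checkpoint { global_clk: 1 << 20, pc: 0x1000 };

    // Save the checkpoint to a temp file, as the proving loop does after `execute_state`.
    let mut file = tempfile::tempfile().expect("failed to create temp file");
    let mut writer = BufWriter::new(&mut file);
    bincode::serialize_into(&mut writer, &state).expect("failed to serialize state");
    writer.flush().expect("failed to flush");
    drop(writer);

    // Rewind (the diff's `reset_seek`) and re-read the state to re-trace the shard.
    file.seek(SeekFrom::Start(0)).expect("failed to seek");
    let recovered: Checkpoint = bincode::deserialize_from(BufReader::new(&mut file))
        .expect("failed to deserialize state");
    assert_eq!(recovered, state);
}
```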
@@ -162,8 +186,9 @@ where let mut shard_main_datas = Vec::new(); let mut challenger = machine.config().challenger(); vk.observe_into(&mut challenger); - for checkpoint_file in checkpoints.iter_mut() { - let mut record = trace_checkpoint(program.clone(), checkpoint_file, opts); + for (num, checkpoint_file) in checkpoints.iter_mut().enumerate() { + let (mut record, _) = tracing::info_span!("commit_checkpoint", num) + .in_scope(|| trace_checkpoint(program.clone(), checkpoint_file, opts)); record.public_values = public_values; reset_seek(&mut *checkpoint_file); @@ -185,9 +210,12 @@ where // For each checkpoint, generate events and shard again, then prove the shards. let mut shard_proofs = Vec::>::new(); - for mut checkpoint_file in checkpoints.into_iter() { + let mut report_aggregate = ExecutionReport::default(); + for (num, mut checkpoint_file) in checkpoints.into_iter().enumerate() { let checkpoint_shards = { - let mut events = trace_checkpoint(program.clone(), &checkpoint_file, opts); + let (mut events, report) = tracing::info_span!("prove_checkpoint", num) + .in_scope(|| trace_checkpoint(program.clone(), &checkpoint_file, opts)); + report_aggregate += report; events.public_values = public_values; reset_seek(&mut checkpoint_file); tracing::debug_span!("shard").in_scope(|| machine.shard(events, &sharding_config)) @@ -215,6 +243,23 @@ where .collect::>(); shard_proofs.append(&mut checkpoint_proofs); } + // Log some of the `ExecutionReport` information. + tracing::info!( + "execution report (totals): total_cycles={}, total_syscall_cycles={}", + report_aggregate.total_instruction_count(), + report_aggregate.total_syscall_count() + ); + // Print the opcode and syscall count tables like `du`: + // sorted by count (descending) and with the count in the first column. + tracing::info!("execution report (opcode counts):"); + for line in ExecutionReport::sorted_table_lines(&report_aggregate.opcode_counts) { + tracing::info!(" {line}"); + } + tracing::info!("execution report (syscall counts):"); + for line in ExecutionReport::sorted_table_lines(&report_aggregate.syscall_counts) { + tracing::info!(" {line}"); + } + let proof = MachineProof:: { shard_proofs }; // Print the summary. @@ -326,13 +371,20 @@ where Ok(proof) } -fn trace_checkpoint(program: Program, file: &File, opts: SP1CoreOpts) -> ExecutionRecord { +fn trace_checkpoint( + program: Program, + file: &File, + opts: SP1CoreOpts, +) -> (ExecutionRecord, ExecutionReport) { let mut reader = std::io::BufReader::new(file); let state = bincode::deserialize_from(&mut reader).expect("failed to deserialize state"); let mut runtime = Runtime::recover(program.clone(), state, opts); + // We already passed the deferred proof verifier when creating checkpoints, so the proofs were + // already verified. So here we use a noop verifier to not print any warnings. 
+ runtime.subproof_verifier = Arc::new(NoOpSubproofVerifier); let (events, _) = tracing::debug_span!("runtime.trace").in_scope(|| runtime.execute_record().unwrap()); - events + (events, runtime.report) } fn reset_seek(file: &mut File) { diff --git a/eval/Cargo.toml b/eval/Cargo.toml index b7f7cb19f3..a9d8207870 100644 --- a/eval/Cargo.toml +++ b/eval/Cargo.toml @@ -7,6 +7,6 @@ edition = "2021" sp1-core = { path = "../core" } sp1-prover = { path = "../prover" } -clap = { version = "4.5.4", features = ["derive"] } +clap = { version = "4.5.7", features = ["derive"] } csv = "1.3.0" serde = "1.0.201" diff --git a/examples/Cargo.lock b/examples/Cargo.lock index 76c62068b2..bc5ac9e3ef 100644 --- a/examples/Cargo.lock +++ b/examples/Cargo.lock @@ -88,9 +88,9 @@ checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" [[package]] name = "alloy-primitives" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db8aa973e647ec336810a9356af8aea787249c9d00b1525359f3db29a68d231b" +checksum = "f783611babedbbe90db3478c120fb5f5daacceffc210b39adc0af4fe0da70bad" dependencies = [ "alloy-rlp", "bytes", @@ -120,9 +120,9 @@ dependencies = [ [[package]] name = "alloy-sol-macro" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7dbd17d67f3e89478c8a634416358e539e577899666c927bc3d2b1328ee9b6ca" +checksum = "4bad41a7c19498e3f6079f7744656328699f8ea3e783bdd10d85788cd439f572" dependencies = [ "alloy-sol-macro-expander", "alloy-sol-macro-input", @@ -134,13 +134,13 @@ dependencies = [ [[package]] name = "alloy-sol-macro-expander" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6da95adcf4760bb4b108fefa51d50096c5e5fdd29ee72fed3e86ee414f2e34" +checksum = "fd9899da7d011b4fe4c406a524ed3e3f963797dbc93b45479d60341d3a27b252" dependencies = [ "alloy-sol-macro-input", "const-hex", - "heck 0.4.1", + "heck", "indexmap 2.2.6", "proc-macro-error", "proc-macro2", @@ -152,13 +152,13 @@ dependencies = [ [[package]] name = "alloy-sol-macro-input" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32c8da04c1343871fb6ce5a489218f9c85323c8340a36e9106b5fc98d4dd59d5" +checksum = "d32d595768fdc61331a132b6f65db41afae41b9b97d36c21eb1b955c422a7e60" dependencies = [ "const-hex", "dunce", - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.66", @@ -167,9 +167,9 @@ dependencies = [ [[package]] name = "alloy-sol-types" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40a64d2d2395c1ac636b62419a7b17ec39031d6b2367e66e9acbf566e6055e9c" +checksum = "a49042c6d3b66a9fe6b2b5a8bf0d39fc2ae1ee0310a2a26ffedd79fb097878dd" dependencies = [ "alloy-primitives", "alloy-sol-macro", @@ -702,6 +702,12 @@ version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" +[[package]] +name = "bytemuck" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78834c15cb5d5efe3452d58b1e8ba890dd62d21907f867f383358198e56ebca5" + [[package]] name = "byteorder" version = "1.5.0" @@ -819,9 +825,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.4" +version = "4.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +checksum = "5db83dced34638ad474f39f250d7fea9598bdd239eaced1bdf45d597da0f433f" dependencies = [ "clap_builder", "clap_derive", @@ -829,9 +835,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.2" +version = "4.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +checksum = "f7e204572485eb3fbf28f871612191521df159bc3e15a9f5064c66dba3a8c05f" dependencies = [ "anstream", "anstyle", @@ -841,11 +847,11 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.4" +version = "4.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" +checksum = "c780290ccf4fb26629baa7a1081e68ced113f1d3ec302fa5948f1c381ebf06c6" dependencies = [ - "heck 0.5.0", + "heck", "proc-macro2", "quote", "syn 2.0.66", @@ -2083,12 +2089,6 @@ dependencies = [ "fxhash", ] -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "heck" version = "0.5.0" @@ -2492,6 +2492,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -4081,9 +4090,9 @@ dependencies = [ [[package]] name = "ruint" -version = "1.12.1" +version = "1.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f308135fef9fc398342da5472ce7c484529df23743fb7c734e0f3d472971e62" +checksum = "2c3cc4c2511671f327125da14133d0c5c5d137f006a1017a16f557bc85b16286" dependencies = [ "alloy-rlp", "ark-ff 0.3.0", @@ -4105,9 +4114,9 @@ dependencies = [ [[package]] name = "ruint-macro" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f86854cf50259291520509879a5c294c3c9a4c334e9ff65071c51e42ef1e2343" +checksum = "48fd7bd8a6377e15ad9d42a8ec25371b94ddc67abe7c8b9127bec79bebaaae18" [[package]] name = "rustc-demangle" @@ -4678,13 +4687,14 @@ dependencies = [ "arrayref", "bincode", "blake3", + "bytemuck", "cfg-if", "curve25519-dalek", "elf", "elliptic-curve", "generic-array 1.0.0", "hex", - "itertools 0.12.1", + "itertools 0.13.0", "k256", "log", "nohash-hasher", @@ -4708,6 +4718,7 @@ dependencies = [ "p3-symmetric", "p3-uni-stark", "p3-util", + "rand", "rayon-scan", "rrs-lib", "serde", @@ -4748,7 +4759,7 @@ dependencies = [ name = "sp1-primitives" version = "0.1.0" dependencies = [ - "itertools 0.12.1", + "itertools 0.13.0", "lazy_static", "p3-baby-bear", "p3-field", @@ -4763,18 +4774,21 @@ dependencies = [ "anyhow", "backtrace", "bincode", + "bytemuck", "clap", "dirs", "futures", "hex", "indicatif", - "itertools 0.12.1", + "itertools 0.13.0", "num-bigint 0.4.5", "p3-baby-bear", "p3-bn254-fr", "p3-challenger", "p3-commit", "p3-field", + "p3-util", + "rand", "rayon", "reqwest 0.12.4", "serde", @@ -4802,7 +4816,7 @@ name = "sp1-recursion-circuit" version = "0.1.0" dependencies = [ "bincode", - "itertools 0.12.1", + "itertools 0.13.0", "p3-air", "p3-baby-bear", "p3-bn254-fr", @@ -4824,7 +4838,7 @@ name = "sp1-recursion-compiler" version = "0.1.0" dependencies = [ "backtrace", - "itertools 0.12.1", + 
"itertools 0.13.0", "p3-air", "p3-baby-bear", "p3-bn254-fr", @@ -4850,7 +4864,7 @@ dependencies = [ "backtrace", "ff 0.13.0", "hashbrown 0.14.5", - "itertools 0.12.1", + "itertools 0.13.0", "p3-air", "p3-baby-bear", "p3-bn254-fr", @@ -4864,6 +4878,7 @@ dependencies = [ "p3-merkle-tree", "p3-poseidon2", "p3-symmetric", + "p3-util", "serde", "serde_with", "sp1-core", @@ -4887,6 +4902,8 @@ dependencies = [ name = "sp1-recursion-gnark-ffi" version = "0.1.0" dependencies = [ + "anyhow", + "bincode", "bindgen", "cc", "cfg-if", @@ -4894,9 +4911,11 @@ dependencies = [ "num-bigint 0.4.5", "p3-baby-bear", "p3-field", + "p3-symmetric", "rand", "serde", "serde_json", + "sp1-core", "sp1-recursion-compiler", "tempfile", ] @@ -4905,7 +4924,7 @@ dependencies = [ name = "sp1-recursion-program" version = "0.1.0" dependencies = [ - "itertools 0.12.1", + "itertools 0.13.0", "p3-air", "p3-baby-bear", "p3-challenger", @@ -4931,7 +4950,6 @@ dependencies = [ name = "sp1-sdk" version = "0.1.0" dependencies = [ - "alloy-primitives", "alloy-sol-types", "anyhow", "async-trait", @@ -5028,11 +5046,11 @@ dependencies = [ [[package]] name = "strum_macros" -version = "0.26.2" +version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6cf59daf282c0a494ba14fd21610a0325f9f90ec9d1231dea26bcb1d696c946" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck 0.4.1", + "heck", "proc-macro2", "quote", "rustversion", @@ -5084,9 +5102,9 @@ dependencies = [ [[package]] name = "syn-solidity" -version = "0.7.4" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8db114c44cf843a8bacd37a146e37987a0b823a0e8bc4fdc610c9c72ab397a5" +checksum = "8d71e19bca02c807c9faa67b5a47673ff231b6e7449b251695188522f1dc44b2" dependencies = [ "paste", "proc-macro2", @@ -5313,9 +5331,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.37.0" +version = "1.38.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1adbebffeca75fcfd058afa480fb6c0b81e165a0323f9c9d39c9697e37c46787" +checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" dependencies = [ "backtrace", "bytes", @@ -5332,9 +5350,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +checksum = "5f5ae998a069d4b5aba8ee9dad856af7d520c3699e6159b185c2acd48155d39a" dependencies = [ "proc-macro2", "quote", diff --git a/examples/Cargo.toml b/examples/Cargo.toml index a6ea63df0d..65d4cb5902 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -13,3 +13,4 @@ members = [ "ssz-withdrawals/script", "tendermint/script", ] +resolver = "2" \ No newline at end of file diff --git a/examples/aggregation/program/Cargo.lock b/examples/aggregation/program/Cargo.lock index 692e2f3662..de974399f6 100644 --- a/examples/aggregation/program/Cargo.lock +++ b/examples/aggregation/program/Cargo.lock @@ -306,6 +306,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "k256" version = "0.13.3" @@ -454,7 +463,7 @@ name = "p3-field" version = "0.1.0" source = 
"git+https://github.com/Plonky3/Plonky3.git?rev=3b5265f9d5af36534a46caebf0617595cfb42c5a#3b5265f9d5af36534a46caebf0617595cfb42c5a" dependencies = [ - "itertools", + "itertools 0.12.1", "num-bigint", "num-traits", "p3-util", @@ -467,7 +476,7 @@ name = "p3-matrix" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=3b5265f9d5af36534a46caebf0617595cfb42c5a#3b5265f9d5af36534a46caebf0617595cfb42c5a" dependencies = [ - "itertools", + "itertools 0.12.1", "p3-field", "p3-maybe-rayon", "p3-util", @@ -486,7 +495,7 @@ name = "p3-mds" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=3b5265f9d5af36534a46caebf0617595cfb42c5a#3b5265f9d5af36534a46caebf0617595cfb42c5a" dependencies = [ - "itertools", + "itertools 0.12.1", "p3-dft", "p3-field", "p3-matrix", @@ -512,7 +521,7 @@ name = "p3-symmetric" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=3b5265f9d5af36534a46caebf0617595cfb42c5a#3b5265f9d5af36534a46caebf0617595cfb42c5a" dependencies = [ - "itertools", + "itertools 0.12.1", "p3-field", "serde", ] @@ -763,7 +772,7 @@ dependencies = [ name = "sp1-primitives" version = "0.1.0" dependencies = [ - "itertools", + "itertools 0.13.0", "lazy_static", "p3-baby-bear", "p3-field", @@ -779,6 +788,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "p3-baby-bear", diff --git a/examples/aggregation/program/elf/riscv32im-succinct-zkvm-elf b/examples/aggregation/program/elf/riscv32im-succinct-zkvm-elf index 755e55b012..d256639d5f 100755 Binary files a/examples/aggregation/program/elf/riscv32im-succinct-zkvm-elf and b/examples/aggregation/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/aggregation/script/Cargo.toml b/examples/aggregation/script/Cargo.toml index 1e10877537..b8c50fea51 100644 --- a/examples/aggregation/script/Cargo.toml +++ b/examples/aggregation/script/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] hex = "0.4.3" -sp1-sdk = { path = "../../../sdk", features = ["plonk"] } +sp1-sdk = { path = "../../../sdk" } tracing = "0.1.40" [build-dependencies] diff --git a/examples/chess/program/Cargo.lock b/examples/chess/program/Cargo.lock index 56d9149d8c..484054dd96 100644 --- a/examples/chess/program/Cargo.lock +++ b/examples/chess/program/Cargo.lock @@ -380,6 +380,12 @@ dependencies = [ "signature", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.154" @@ -790,6 +796,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand 0.8.5", diff --git a/examples/chess/program/elf/riscv32im-succinct-zkvm-elf b/examples/chess/program/elf/riscv32im-succinct-zkvm-elf index 0466a204d4..6a9bf30352 100755 Binary files a/examples/chess/program/elf/riscv32im-succinct-zkvm-elf and b/examples/chess/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/chess/script/src/main.rs b/examples/chess/script/src/main.rs index 0d3b8f313e..de8f8f9dcc 100644 --- a/examples/chess/script/src/main.rs +++ b/examples/chess/script/src/main.rs @@ -1,4 +1,4 @@ -use sp1_sdk::{ProverClient, SP1Stdin}; +use sp1_sdk::{ProverClient, SP1Proof, SP1Stdin}; const ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf"); @@ -24,10 +24,16 @@ fn main() { // Verify proof. client.verify(&proof, &vk).expect("verification failed"); - // Save proof. 
+ // Test a round trip of proof serialization and deserialization. proof - .save("proof-with-io.json") + .save("proof-with-io.bin") .expect("saving proof failed"); + let deserialized_proof = SP1Proof::load("proof-with-io.bin").expect("loading proof failed"); + + // Verify the deserialized proof. + client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/examples/cycle-tracking/program/Cargo.lock b/examples/cycle-tracking/program/Cargo.lock index b820d2bb26..7c539c31b4 100644 --- a/examples/cycle-tracking/program/Cargo.lock +++ b/examples/cycle-tracking/program/Cargo.lock @@ -298,6 +298,12 @@ dependencies = [ "signature", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.154" @@ -641,6 +647,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand", diff --git a/examples/cycle-tracking/program/elf/riscv32im-succinct-zkvm-elf b/examples/cycle-tracking/program/elf/riscv32im-succinct-zkvm-elf index 7748da170e..c0c915cd22 100755 Binary files a/examples/cycle-tracking/program/elf/riscv32im-succinct-zkvm-elf and b/examples/cycle-tracking/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/cycle-tracking/script/src/main.rs b/examples/cycle-tracking/script/src/main.rs index 85fbe7ceb1..81e534f997 100644 --- a/examples/cycle-tracking/script/src/main.rs +++ b/examples/cycle-tracking/script/src/main.rs @@ -1,4 +1,4 @@ -use sp1_sdk::{utils, ProverClient, SP1Stdin}; +use sp1_sdk::{utils, ProverClient, SP1Proof, SP1Stdin}; /// The ELF we want to execute inside the zkVM. const ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf"); @@ -18,10 +18,16 @@ fn main() { // Verify proof. client.verify(&proof, &vk).expect("verification failed"); - // Save the proof. + // Test a round trip of proof serialization and deserialization. proof - .save("proof-with-pis.json") + .save("proof-with-pis.bin") .expect("saving proof failed"); + let deserialized_proof = SP1Proof::load("proof-with-pis.bin").expect("loading proof failed"); + + // Verify the deserialized proof. 
+ client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/examples/fibonacci/program/Cargo.lock b/examples/fibonacci/program/Cargo.lock index 8d19871d4e..9a8dc81508 100644 --- a/examples/fibonacci/program/Cargo.lock +++ b/examples/fibonacci/program/Cargo.lock @@ -297,6 +297,12 @@ dependencies = [ "signature", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.154" @@ -631,6 +637,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand", diff --git a/examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf b/examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf index e67522c9f4..2d2682e462 100755 Binary files a/examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf and b/examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/fibonacci/script/bin/compressed.rs b/examples/fibonacci/script/bin/compressed.rs index 8c0de7973a..9dd3f52690 100644 --- a/examples/fibonacci/script/bin/compressed.rs +++ b/examples/fibonacci/script/bin/compressed.rs @@ -30,7 +30,7 @@ fn main() { // Save the proof. proof - .save("compressed-proof-with-pis.json") + .save("compressed-proof-with-pis.bin") .expect("saving proof failed"); println!("successfully generated and verified proof for the program!") diff --git a/examples/fibonacci/script/bin/plonk_bn254.rs b/examples/fibonacci/script/bin/plonk_bn254.rs index 7057f20fab..2bff4fb21c 100644 --- a/examples/fibonacci/script/bin/plonk_bn254.rs +++ b/examples/fibonacci/script/bin/plonk_bn254.rs @@ -34,7 +34,7 @@ fn main() { // Save the proof. proof - .save("proof-with-pis.json") + .save("proof-with-pis.bin") .expect("saving proof failed"); println!("successfully generated and verified proof for the program!") diff --git a/examples/fibonacci/script/src/main.rs b/examples/fibonacci/script/src/main.rs index 0aa84a8a12..d7de54e3ab 100644 --- a/examples/fibonacci/script/src/main.rs +++ b/examples/fibonacci/script/src/main.rs @@ -1,4 +1,4 @@ -use sp1_sdk::{utils, ProverClient, SP1Stdin}; +use sp1_sdk::{utils, ProverClient, SP1Proof, SP1Stdin}; /// The ELF we want to execute inside the zkVM. const ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf"); @@ -16,7 +16,7 @@ fn main() { // Generate the proof for the given program and input. let client = ProverClient::new(); let (pk, vk) = client.setup(ELF); - let mut proof = client.prove_compressed(&pk, stdin).unwrap(); + let mut proof = client.prove(&pk, stdin).unwrap(); println!("generated proof"); @@ -29,14 +29,18 @@ fn main() { println!("b: {}", b); // Verify proof and public values - client - .verify_compressed(&proof, &vk) - .expect("verification failed"); + client.verify(&proof, &vk).expect("verification failed"); - // Save the proof. + // Test a round trip of proof serialization and deserialization. proof - .save("proof-with-pis.json") + .save("proof-with-pis.bin") .expect("saving proof failed"); + let deserialized_proof = SP1Proof::load("proof-with-pis.bin").expect("loading proof failed"); + + // Verify the deserialized proof. 
+ client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/examples/io/program/Cargo.lock b/examples/io/program/Cargo.lock index 60d41a3561..e3fcd779c5 100644 --- a/examples/io/program/Cargo.lock +++ b/examples/io/program/Cargo.lock @@ -298,6 +298,12 @@ dependencies = [ "signature", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.154" @@ -632,6 +638,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand", diff --git a/examples/io/program/elf/riscv32im-succinct-zkvm-elf b/examples/io/program/elf/riscv32im-succinct-zkvm-elf index e5f2f6b5c0..9c86e401ef 100755 Binary files a/examples/io/program/elf/riscv32im-succinct-zkvm-elf and b/examples/io/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/io/script/src/main.rs b/examples/io/script/src/main.rs index 58389008f7..b4aa20786d 100644 --- a/examples/io/script/src/main.rs +++ b/examples/io/script/src/main.rs @@ -1,5 +1,5 @@ use serde::{Deserialize, Serialize}; -use sp1_sdk::{utils, ProverClient, SP1Stdin}; +use sp1_sdk::{utils, ProverClient, SP1Proof, SP1Stdin}; /// The ELF we want to execute inside the zkVM. const ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf"); @@ -42,10 +42,16 @@ fn main() { // Verify proof. client.verify(&proof, &vk).expect("verification failed"); - // Save the proof. + // Test a round trip of proof serialization and deserialization. proof - .save("proof-with-pis.json") + .save("proof-with-pis.bin") .expect("saving proof failed"); + let deserialized_proof = SP1Proof::load("proof-with-pis.bin").expect("loading proof failed"); + + // Verify the deserialized proof. + client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/examples/is-prime/script/src/main.rs b/examples/is-prime/script/src/main.rs index 6e63191450..281c361545 100644 --- a/examples/is-prime/script/src/main.rs +++ b/examples/is-prime/script/src/main.rs @@ -1,5 +1,5 @@ //! A program that takes a number `n` as input, and writes if `n` is prime as an output. -use sp1_sdk::{utils, ProverClient, SP1Stdin}; +use sp1_sdk::{utils, ProverClient, SP1Proof, SP1Stdin}; const ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf"); @@ -23,10 +23,17 @@ fn main() { client.verify(&proof, &vk).expect("verification failed"); - // Save the proof + // Test a round trip of proof serialization and deserialization. proof - .save("proof-with-is-prime.json") + .save("proof-with-is-prime.bin") .expect("saving proof failed"); + let deserialized_proof = + SP1Proof::load("proof-with-is-prime.bin").expect("loading proof failed"); + + // Verify the deserialized proof. 
+ client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/examples/json/program/Cargo.lock b/examples/json/program/Cargo.lock index ba66915b39..66b23ba7f8 100644 --- a/examples/json/program/Cargo.lock +++ b/examples/json/program/Cargo.lock @@ -306,6 +306,12 @@ dependencies = [ "signature", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "lib" version = "0.1.0" @@ -664,6 +670,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand", diff --git a/examples/json/program/elf/riscv32im-succinct-zkvm-elf b/examples/json/program/elf/riscv32im-succinct-zkvm-elf index b37c99f583..ad80a3fbb4 100755 Binary files a/examples/json/program/elf/riscv32im-succinct-zkvm-elf and b/examples/json/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/json/script/src/main.rs b/examples/json/script/src/main.rs index ef0abdd71c..2e3a015112 100644 --- a/examples/json/script/src/main.rs +++ b/examples/json/script/src/main.rs @@ -1,7 +1,7 @@ //! A simple script to generate and verify the proof of a given program. use lib::{Account, Transaction}; -use sp1_sdk::{utils, ProverClient, SP1Stdin}; +use sp1_sdk::{utils, ProverClient, SP1Proof, SP1Stdin}; const JSON_ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf"); @@ -62,10 +62,16 @@ fn main() { // Verify proof. client.verify(&proof, &vk).expect("verification failed"); - // Save proof. + // Test a round trip of proof serialization and deserialization. proof - .save("proof-with-io.json") + .save("proof-with-io.bin") .expect("saving proof failed"); + let deserialized_proof = SP1Proof::load("proof-with-io.bin").expect("loading proof failed"); + + // Verify the deserialized proof. + client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/examples/regex/program/Cargo.lock b/examples/regex/program/Cargo.lock index ad1e86bd22..d5a4c9bd5d 100644 --- a/examples/regex/program/Cargo.lock +++ b/examples/regex/program/Cargo.lock @@ -299,6 +299,12 @@ dependencies = [ "signature", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.154" @@ -670,6 +676,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand", diff --git a/examples/regex/program/elf/riscv32im-succinct-zkvm-elf b/examples/regex/program/elf/riscv32im-succinct-zkvm-elf index 3cdf60299e..8114ee471d 100755 Binary files a/examples/regex/program/elf/riscv32im-succinct-zkvm-elf and b/examples/regex/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/regex/script/src/main.rs b/examples/regex/script/src/main.rs index 846233db25..dd8b35d7fa 100644 --- a/examples/regex/script/src/main.rs +++ b/examples/regex/script/src/main.rs @@ -1,4 +1,4 @@ -use sp1_sdk::{utils, ProverClient, SP1Stdin}; +use sp1_sdk::{utils, ProverClient, SP1Proof, SP1Stdin}; /// The ELF we want to execute inside the zkVM. 
const REGEX_IO_ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf"); @@ -29,10 +29,16 @@ fn main() { // Verify proof. client.verify(&proof, &vk).expect("verification failed"); - // Save the proof. + // Test a round trip of proof serialization and deserialization. proof - .save("proof-with-pis.json") + .save("proof-with-pis.bin") .expect("saving proof failed"); + let deserialized_proof = SP1Proof::load("proof-with-pis.bin").expect("loading proof failed"); + + // Verify the deserialized proof. + client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/examples/rsa/program/Cargo.lock b/examples/rsa/program/Cargo.lock index 7b149150da..59a046b519 100644 --- a/examples/rsa/program/Cargo.lock +++ b/examples/rsa/program/Cargo.lock @@ -788,6 +788,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand", diff --git a/examples/rsa/program/elf/riscv32im-succinct-zkvm-elf b/examples/rsa/program/elf/riscv32im-succinct-zkvm-elf index 531ed1e072..688b17c79b 100755 Binary files a/examples/rsa/program/elf/riscv32im-succinct-zkvm-elf and b/examples/rsa/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/rsa/script/src/main.rs b/examples/rsa/script/src/main.rs index f6c1c31737..4b3f237308 100644 --- a/examples/rsa/script/src/main.rs +++ b/examples/rsa/script/src/main.rs @@ -2,7 +2,7 @@ use rsa::{ pkcs8::{DecodePrivateKey, DecodePublicKey}, RsaPrivateKey, RsaPublicKey, }; -use sp1_sdk::{utils, ProverClient, SP1Stdin}; +use sp1_sdk::{utils, ProverClient, SP1Proof, SP1Stdin}; use std::vec; /// The ELF we want to execute inside the zkVM. @@ -59,10 +59,14 @@ fn main() { // Verify proof. client.verify(&proof, &vk).expect("verification failed"); - // Save the proof. - proof - .save("proof-with-pis.json") - .expect("saving proof failed"); + // Test a round trip of proof serialization and deserialization. + proof.save("proof-with-pis").expect("saving proof failed"); + let deserialized_proof = SP1Proof::load("proof-with-pis").expect("loading proof failed"); + + // Verify the deserialized proof. 
+ client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/examples/ssz-withdrawals/program/Cargo.lock b/examples/ssz-withdrawals/program/Cargo.lock index 874f71e34b..81a8dedbc7 100644 --- a/examples/ssz-withdrawals/program/Cargo.lock +++ b/examples/ssz-withdrawals/program/Cargo.lock @@ -1426,6 +1426,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand", diff --git a/examples/ssz-withdrawals/program/elf/riscv32im-succinct-zkvm-elf b/examples/ssz-withdrawals/program/elf/riscv32im-succinct-zkvm-elf index a4b95a24cf..599bd4066d 100755 Binary files a/examples/ssz-withdrawals/program/elf/riscv32im-succinct-zkvm-elf and b/examples/ssz-withdrawals/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/ssz-withdrawals/script/src/main.rs b/examples/ssz-withdrawals/script/src/main.rs index 9713ca1cd7..34b1013835 100644 --- a/examples/ssz-withdrawals/script/src/main.rs +++ b/examples/ssz-withdrawals/script/src/main.rs @@ -1,4 +1,4 @@ -use sp1_sdk::{utils, ProverClient, SP1Stdin}; +use sp1_sdk::{utils, ProverClient, SP1Proof, SP1Stdin}; const ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf"); @@ -10,17 +10,21 @@ fn main() { let stdin = SP1Stdin::new(); let client = ProverClient::new(); let (pk, vk) = client.setup(ELF); - let proof = client.prove_compressed(&pk, stdin).expect("proving failed"); + let proof = client.prove(&pk, stdin).expect("proving failed"); // Verify proof. - client - .verify_compressed(&proof, &vk) - .expect("verification failed"); + client.verify(&proof, &vk).expect("verification failed"); - // Save proof. + // Test a round trip of proof serialization and deserialization. proof - .save("proof-with-pis.json") + .save("proof-with-pis.bin") .expect("saving proof failed"); + let deserialized_proof = SP1Proof::load("proof-with-pis.bin").expect("loading proof failed"); + + // Verify the deserialized proof. 
+ client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/examples/tendermint/program/Cargo.lock b/examples/tendermint/program/Cargo.lock index 780d188934..ed1af56c23 100644 --- a/examples/tendermint/program/Cargo.lock +++ b/examples/tendermint/program/Cargo.lock @@ -463,6 +463,12 @@ dependencies = [ "signature", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" version = "0.2.154" @@ -935,6 +941,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand", diff --git a/examples/tendermint/program/elf/riscv32im-succinct-zkvm-elf b/examples/tendermint/program/elf/riscv32im-succinct-zkvm-elf index 8227fc8b85..a9159cc25c 100755 Binary files a/examples/tendermint/program/elf/riscv32im-succinct-zkvm-elf and b/examples/tendermint/program/elf/riscv32im-succinct-zkvm-elf differ diff --git a/examples/tendermint/script/src/main.rs b/examples/tendermint/script/src/main.rs index 9fe31874a1..e2fa14837d 100644 --- a/examples/tendermint/script/src/main.rs +++ b/examples/tendermint/script/src/main.rs @@ -1,3 +1,4 @@ +use sp1_sdk::SP1Proof; use std::time::Duration; use tokio::runtime::Runtime; @@ -62,12 +63,10 @@ fn main() { let client = ProverClient::new(); let (pk, vk) = client.setup(TENDERMINT_ELF); - let proof = client.prove_compressed(&pk, stdin).expect("proving failed"); + let proof = client.prove(&pk, stdin).expect("proving failed"); // Verify proof. - client - .verify_compressed(&proof, &vk) - .expect("verification failed"); + client.verify(&proof, &vk).expect("verification failed"); // Verify the public values let mut expected_public_values: Vec = Vec::new(); @@ -75,15 +74,18 @@ fn main() { expected_public_values.extend(light_block_2.signed_header.header.hash().as_bytes()); expected_public_values.extend(serde_cbor::to_vec(&expected_verdict).unwrap()); - assert_eq!( - proof.public_values.as_ref(), - expected_public_values - ); + assert_eq!(proof.public_values.as_ref(), expected_public_values); - // Save proof. + // Test a round trip of proof serialization and deserialization. proof - .save("proof-with-pis.json") + .save("proof-with-pis.bin") .expect("saving proof failed"); + let deserialized_proof = SP1Proof::load("proof-with-pis.bin").expect("loading proof failed"); + + // Verify the deserialized proof. 
+ client + .verify(&deserialized_proof, &vk) + .expect("verification failed"); println!("successfully generated and verified proof for the program!") } diff --git a/primitives/Cargo.toml b/primitives/Cargo.toml index 1f48703fb2..371b9f46a6 100644 --- a/primitives/Cargo.toml +++ b/primitives/Cargo.toml @@ -9,4 +9,4 @@ p3-field = { workspace = true } p3-baby-bear = { workspace = true } p3-poseidon2 = { workspace = true } p3-symmetric = { workspace = true } -itertools = "0.12.1" +itertools = "0.13.0" diff --git a/prover/Cargo.toml b/prover/Cargo.toml index 3bb4890ee5..9a508fb74a 100644 --- a/prover/Cargo.toml +++ b/prover/Cargo.toml @@ -16,22 +16,23 @@ p3-challenger = { workspace = true } p3-baby-bear = { workspace = true } p3-bn254-fr = { workspace = true } p3-commit = { workspace = true } +p3-util = { workspace = true } bincode = "1.3.3" serde = { version = "1.0", features = ["derive", "rc"] } backtrace = "0.3.71" rayon = "1.10.0" -itertools = "0.12.1" +itertools = "0.13.0" tracing = "0.1.40" tracing-subscriber = "0.3.18" serde_json = "1.0.117" -clap = { version = "4.5.4", features = ["derive", "env"] } +clap = { version = "4.5.7", features = ["derive", "env"] } sha2 = "0.10.8" hex = "0.4.3" anyhow = "1.0.83" size = "0.4.1" dirs = "5.0.1" tempfile = "3.10.1" -tokio = { version = "1.37.0", features = ["full"] } +tokio = { version = "1.38.0", features = ["full"] } reqwest = { version = "0.12.4", features = [ "rustls-tls", "trust-dns", @@ -43,6 +44,8 @@ subtle-encoding = "0.5.1" serial_test = "3.1.1" num-bigint = "0.4.5" thiserror = "1.0.60" +bytemuck = "1.16.0" +rand = "0.8.4" [[bin]] name = "build_plonk_bn254" @@ -54,4 +57,4 @@ path = "scripts/e2e.rs" [features] neon = ["sp1-core/neon"] -plonk = ["sp1-recursion-gnark-ffi/plonk"] +native-gnark = ["sp1-recursion-gnark-ffi/native"] diff --git a/prover/Makefile b/prover/Makefile index a63ebbbea0..b4e0999616 100644 --- a/prover/Makefile +++ b/prover/Makefile @@ -6,12 +6,13 @@ build-plonk-bn254: rm -rf build && \ mkdir -p build && \ RUSTFLAGS='-C target-cpu=native' \ - cargo run -p sp1-prover --release --bin build_plonk_bn254 --features plonk -- \ + cargo run -p sp1-prover --release --bin build_plonk_bn254 -- \ --build-dir=./build release-plonk-bn254: - bash release.sh + @read -p "Release version (ex. v1.0.0-testnet)? " version; \ + bash release.sh $$version test-e2e: RUSTFLAGS='-C target-cpu=native' \ - cargo test --package sp1-prover --lib --release -- tests::test_e2e --exact --show-output \ No newline at end of file + cargo test --package sp1-prover --lib --release -- tests::test_e2e --exact --show-output diff --git a/prover/release.sh b/prover/release.sh index 0d8601c067..61ffc55022 100644 --- a/prover/release.sh +++ b/prover/release.sh @@ -1,6 +1,9 @@ #!/bin/bash set -e +# Get the version from the command line. +VERSION=$1 + # Specify the file to upload and the S3 bucket name FILE_TO_UPLOAD="./build" S3_BUCKET="sp1-circuits" @@ -18,8 +21,11 @@ if [ $? -ne 0 ]; then exit 1 fi +# Put the version in the build directory +echo "$COMMIT_HASH $VERSION" > ./build/SP1_COMMIT + # Create archive named after the commit hash -ARCHIVE_NAME="${COMMIT_HASH}.tar.gz" +ARCHIVE_NAME="${VERSION}.tar.gz" cd $FILE_TO_UPLOAD tar --exclude='srs.bin' --exclude='srs_lagrange.bin' -czvf "../$ARCHIVE_NAME" . cd - @@ -35,4 +41,4 @@ if [ $? 
-ne 0 ]; then exit 1 fi -echo "succesfully uploaded build artifacts to s3://$S3_BUCKET/$ARCHIVE_NAME" \ No newline at end of file +echo "succesfully uploaded build artifacts to s3://$S3_BUCKET/$ARCHIVE_NAME" diff --git a/prover/src/build.rs b/prover/src/build.rs index 129ec454a8..af7ab82172 100644 --- a/prover/src/build.rs +++ b/prover/src/build.rs @@ -12,9 +12,9 @@ use sp1_recursion_core::air::RecursionPublicValues; pub use sp1_recursion_core::stark::utils::sp1_dev_mode; use sp1_recursion_gnark_ffi::PlonkBn254Prover; -use crate::install::{install_plonk_bn254_artifacts, PLONK_BN254_ARTIFACTS_COMMIT}; +use crate::install::install_plonk_bn254_artifacts; use crate::utils::{babybear_bytes_to_bn254, babybears_to_bn254, words_to_bytes}; -use crate::{OuterSC, SP1Prover}; +use crate::{OuterSC, SP1Prover, SP1_CIRCUIT_VERSION}; /// Tries to install the PLONK artifacts if they are not already installed. pub fn try_install_plonk_bn254_artifacts() -> PathBuf { @@ -27,8 +27,8 @@ pub fn try_install_plonk_bn254_artifacts() -> PathBuf { ); } else { println!( - "[sp1] plonk bn254 artifacts for commit {} do not exist at {}. downloading...", - PLONK_BN254_ARTIFACTS_COMMIT, + "[sp1] plonk bn254 artifacts for version {} do not exist at {}. downloading...", + SP1_CIRCUIT_VERSION, build_dir.display() ); install_plonk_bn254_artifacts(build_dir.clone()); @@ -37,9 +37,6 @@ pub fn try_install_plonk_bn254_artifacts() -> PathBuf { } /// Tries to build the PLONK artifacts inside the development directory. -/// -/// TODO: Maybe add some additional logic here to handle rebuilding the artifacts if they are -/// already built. pub fn try_build_plonk_bn254_artifacts_dev( template_vk: &StarkVerifyingKey, template_proof: &ShardProof, @@ -57,7 +54,7 @@ fn plonk_bn254_artifacts_dir() -> PathBuf { .join(".sp1") .join("circuits") .join("plonk_bn254") - .join(PLONK_BN254_ARTIFACTS_COMMIT) + .join(SP1_CIRCUIT_VERSION) } /// Gets the directory where the PLONK artifacts are installed in development mode. diff --git a/prover/src/install.rs b/prover/src/install.rs index 873c50f685..8971661bd7 100644 --- a/prover/src/install.rs +++ b/prover/src/install.rs @@ -4,14 +4,11 @@ use futures::StreamExt; use indicatif::{ProgressBar, ProgressStyle}; use reqwest::Client; -use crate::utils::block_on; +use crate::{utils::block_on, SP1_CIRCUIT_VERSION}; /// The base URL for the S3 bucket containing the plonk bn254 artifacts. pub const PLONK_BN254_ARTIFACTS_URL_BASE: &str = "https://sp1-circuits.s3-us-east-2.amazonaws.com"; -/// The current version of the plonk bn254 artifacts. -pub const PLONK_BN254_ARTIFACTS_COMMIT: &str = "e48c01ec"; - /// Install the latest plonk bn254 artifacts. /// /// This function will download the latest plonk bn254 artifacts from the S3 bucket and extract them to @@ -23,7 +20,7 @@ pub fn install_plonk_bn254_artifacts(build_dir: PathBuf) { // Download the artifacts. let download_url = format!( "{}/{}.tar.gz", - PLONK_BN254_ARTIFACTS_URL_BASE, PLONK_BN254_ARTIFACTS_COMMIT + PLONK_BN254_ARTIFACTS_URL_BASE, SP1_CIRCUIT_VERSION ); let mut artifacts_tar_gz_file = tempfile::NamedTempFile::new().expect("failed to create tempfile"); @@ -63,7 +60,7 @@ pub fn install_plonk_bn254_artifacts_dir() -> PathBuf { .unwrap() .join(".sp1") .join("circuits") - .join(PLONK_BN254_ARTIFACTS_COMMIT) + .join(SP1_CIRCUIT_VERSION) } /// Download the file with a progress bar that indicates the progress. 
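The example-script changes above all exercise the same new pattern: instead of saving the proof as JSON, each script writes the proof to a `.bin` file, reloads it with `SP1Proof::load`, and verifies the deserialized copy. A minimal sketch of that round trip, assuming the `sp1-sdk` API exactly as it is used in those scripts (the ELF path is a placeholder):

```rust
// Sketch of the proof serialization round trip the example scripts now test.
// Assumes the sp1-sdk API shown in the diffs above; the ELF path is illustrative.
use sp1_sdk::{utils, ProverClient, SP1Proof, SP1Stdin};

const ELF: &[u8] = include_bytes!("../../program/elf/riscv32im-succinct-zkvm-elf");

fn main() {
    utils::setup_logger();

    let stdin = SP1Stdin::new();
    let client = ProverClient::new();
    let (pk, vk) = client.setup(ELF);

    // Prove and verify the freshly generated proof.
    let proof = client.prove(&pk, stdin).expect("proving failed");
    client.verify(&proof, &vk).expect("verification failed");

    // Round-trip the proof through disk and verify the deserialized copy.
    proof
        .save("proof-with-pis.bin")
        .expect("saving proof failed");
    let deserialized_proof = SP1Proof::load("proof-with-pis.bin").expect("loading proof failed");
    client
        .verify(&deserialized_proof, &vk)
        .expect("verification failed");

    println!("successfully generated and verified proof for the program!")
}
```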
diff --git a/prover/src/lib.rs b/prover/src/lib.rs index 73e3f27c0b..8df7eed313 100644 --- a/prover/src/lib.rs +++ b/prover/src/lib.rs @@ -18,6 +18,7 @@ pub mod verify; use std::borrow::Borrow; use std::path::Path; +use std::sync::Arc; use p3_baby_bear::BabyBear; use p3_challenger::CanObserve; @@ -60,6 +61,8 @@ use tracing::instrument; pub use types::*; use utils::words_to_bytes; +pub use sp1_core::SP1_CIRCUIT_VERSION; + /// The configuration for the core prover. pub type CoreSC = BabyBearPoseidon2; @@ -135,6 +138,12 @@ pub struct SP1Prover { /// The machine used for proving the wrapping step. pub wrap_machine: StarkMachine::Val>>, + + /// The options for the core prover. + pub core_opts: SP1CoreOpts, + + /// The options for the recursion prover. + pub recursion_opts: SP1CoreOpts, } impl SP1Prover { @@ -192,6 +201,8 @@ impl SP1Prover { compress_machine, shrink_machine, wrap_machine, + core_opts: SP1CoreOpts::default(), + recursion_opts: SP1CoreOpts::recursion(), } } @@ -239,8 +250,13 @@ impl SP1Prover { ) -> Result { let config = CoreSC::default(); let program = Program::from(&pk.elf); - let opts = SP1CoreOpts::default(); - let (proof, public_values_stream) = sp1_core::utils::prove(program, stdin, config, opts)?; + let (proof, public_values_stream) = sp1_core::utils::prove_with_subproof_verifier( + program, + stdin, + config, + self.core_opts, + Some(Arc::new(self)), + )?; let public_values = SP1PublicValues::from(&public_values_stream); Ok(SP1CoreProof { proof: SP1CoreProofData(proof.shard_proofs), @@ -394,39 +410,19 @@ impl SP1Prover { batch_size, ); - let mut first_layer_proofs = Vec::new(); - let opts = SP1CoreOpts::recursion(); + let mut reduce_proofs = Vec::new(); + let opts = self.recursion_opts; let shard_batch_size = opts.shard_batch_size; for inputs in core_inputs.chunks(shard_batch_size) { let proofs = inputs .into_par_iter() .map(|input| { - let mut runtime = RecursionRuntime::, Challenge, _>::new( - &self.recursion_program, - self.compress_machine.config().perm.clone(), - ); - - let mut witness_stream = Vec::new(); - witness_stream.extend(input.write()); - - runtime.witness_stream = witness_stream.into(); - runtime.run(); - runtime.print_stats(); - - let pk = &self.rec_pk; - let mut recursive_challenger = self.compress_machine.config().challenger(); - ( - self.compress_machine.prove::>( - pk, - runtime.record, - &mut recursive_challenger, - opts, - ), - ReduceProgramType::Core, - ) + let proof = + self.compress_machine_proof(input, &self.recursion_program, &self.rec_pk); + (proof, ReduceProgramType::Core) }) .collect::>(); - first_layer_proofs.extend(proofs); + reduce_proofs.extend(proofs); } // Run the deferred proofs programs. 
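One consequence of storing `core_opts` and `recursion_opts` on `SP1Prover` (rather than constructing `SP1CoreOpts::default()` / `SP1CoreOpts::recursion()` inline) is that callers can tune sharding parameters after the prover is built; the end-to-end test further down does exactly this with `shard_size`. A small sketch of that usage, where the import paths and the concrete value are assumptions for illustration:

```rust
// Sketch: adjust the prover's sharding options after construction.
// SP1Prover and SP1CoreOpts are the types used in this diff; the import
// paths and the concrete shard_size value are assumptions for illustration.
use sp1_core::utils::SP1CoreOpts;
use sp1_prover::SP1Prover;

fn make_test_prover() -> SP1Prover {
    let mut prover = SP1Prover::new();
    // Smaller core shards, e.g. to keep end-to-end test runtimes down.
    prover.core_opts.shard_size = 1 << 12;
    // Recursion keeps its dedicated defaults unless overridden here.
    prover.recursion_opts = SP1CoreOpts::recursion();
    prover
}
```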
@@ -434,40 +430,17 @@ impl SP1Prover { let proofs = inputs .into_par_iter() .map(|input| { - let mut runtime = RecursionRuntime::, Challenge, _>::new( + let proof = self.compress_machine_proof( + input, &self.deferred_program, - self.compress_machine.config().perm.clone(), + &self.deferred_pk, ); - - let mut witness_stream = Vec::new(); - witness_stream.extend(input.write()); - - runtime.witness_stream = witness_stream.into(); - runtime.run(); - runtime.print_stats(); - - let pk = &self.deferred_pk; - let mut recursive_challenger = self.compress_machine.config().challenger(); - ( - self.compress_machine.prove::>( - pk, - runtime.record, - &mut recursive_challenger, - opts, - ), - ReduceProgramType::Deferred, - ) + (proof, ReduceProgramType::Deferred) }) .collect::>(); - first_layer_proofs.extend(proofs); + reduce_proofs.extend(proofs); } - // Chain all the individual shard proofs. - let mut reduce_proofs = first_layer_proofs - .into_iter() - .flat_map(|(proof, kind)| proof.shard_proofs.into_iter().map(move |p| (p, kind))) - .collect::>(); - // Iterate over the recursive proof batches until there is one proof remaining. let mut is_complete; loop { @@ -535,7 +508,7 @@ impl SP1Prover { runtime.run(); runtime.print_stats(); - let opts = SP1CoreOpts::recursion(); + let opts = self.recursion_opts; let mut recursive_challenger = self.compress_machine.config().challenger(); self.compress_machine .prove::>(pk, runtime.record, &mut recursive_challenger, opts) @@ -572,7 +545,7 @@ impl SP1Prover { tracing::debug!("Compress program executed successfully"); // Prove the compress program. - let opts = SP1CoreOpts::recursion(); + let opts = self.recursion_opts; let mut compress_challenger = self.shrink_machine.config().challenger(); let mut compress_proof = self.shrink_machine.prove::>( &self.shrink_pk, @@ -613,7 +586,7 @@ impl SP1Prover { tracing::debug!("Wrap program executed successfully"); // Prove the wrap program. - let opts = SP1CoreOpts::recursion(); + let opts = self.recursion_opts; let mut wrap_challenger = self.wrap_machine.config().challenger(); let time = std::time::Instant::now(); let mut wrap_proof = self.wrap_machine.prove::>( @@ -696,10 +669,10 @@ mod tests { use std::fs::File; use std::io::{Read, Write}; - use self::build::try_build_plonk_bn254_artifacts_dev; use super::*; use anyhow::Result; + use build::try_build_plonk_bn254_artifacts_dev; use p3_field::PrimeField32; use serial_test::serial; use sp1_core::io::SP1Stdin; @@ -718,7 +691,8 @@ mod tests { let elf = include_bytes!("../../tests/fibonacci/elf/riscv32im-succinct-zkvm-elf"); tracing::info!("initializing prover"); - let prover = SP1Prover::new(); + let mut prover = SP1Prover::new(); + prover.core_opts.shard_size = 1 << 12; tracing::info!("setup elf"); let (pk, vk) = prover.setup(elf); @@ -748,11 +722,11 @@ mod tests { let bytes = bincode::serialize(&wrapped_bn254_proof).unwrap(); // Save the proof. - let mut file = File::create("proof-with-pis.json").unwrap(); + let mut file = File::create("proof-with-pis.bin").unwrap(); file.write_all(bytes.as_slice()).unwrap(); // Load the proof. 
- let mut file = File::open("proof-with-pis.json").unwrap(); + let mut file = File::open("proof-with-pis.bin").unwrap(); let mut bytes = Vec::new(); file.read_to_end(&mut bytes).unwrap(); diff --git a/prover/src/verify.rs b/prover/src/verify.rs index fedd467cfc..95865e255d 100644 --- a/prover/src/verify.rs +++ b/prover/src/verify.rs @@ -4,6 +4,8 @@ use anyhow::Result; use num_bigint::BigUint; use p3_baby_bear::BabyBear; use p3_field::{AbstractField, PrimeField}; +use sp1_core::air::MachineAir; +use sp1_core::runtime::SubproofVerifier; use sp1_core::{ air::PublicValues, io::SP1PublicValues, @@ -45,7 +47,7 @@ impl SP1Prover { self.core_machine .verify(&vk.vk, &machine_proof, &mut challenger)?; - // Verify shard transitions + // Verify shard transitions. for (i, shard_proof) in proof.0.iter().enumerate() { let public_values = PublicValues::from_vec(shard_proof.public_values.clone()); // Verify shard transitions @@ -100,6 +102,41 @@ impl SP1Prover { } } + // Verify that the number of shards is not too large. + if proof.0.len() > 1 << 16 { + return Err(MachineVerificationError::TooManyShards); + } + + // Verify that the `MemoryInit` and `MemoryFinalize` chips are the last chips in the proof. + for (i, shard_proof) in proof.0.iter().enumerate() { + let chips = self + .core_machine + .shard_chips_ordered(&shard_proof.chip_ordering) + .collect::>(); + let memory_init_count = chips + .clone() + .into_iter() + .filter(|chip| chip.name() == "MemoryInit") + .count(); + let memory_final_count = chips + .into_iter() + .filter(|chip| chip.name() == "MemoryFinalize") + .count(); + + // Assert that the `MemoryInit` and `MemoryFinalize` chips only exist in the last shard. + if i != 0 && (memory_final_count > 0 || memory_init_count > 0) { + return Err(MachineVerificationError::InvalidChipOccurence( + "memory init and finalize should not eixst anywhere but the last chip" + .to_string(), + )); + } + if i == 0 && (memory_init_count != 1 || memory_final_count != 1) { + return Err(MachineVerificationError::InvalidChipOccurence( + "memory init and finalize should exist the last chip".to_string(), + )); + } + } + Ok(()) } @@ -259,3 +296,37 @@ pub fn verify_plonk_bn254_public_inputs( Ok(()) } + +impl SubproofVerifier for &SP1Prover { + fn verify_deferred_proof( + &self, + proof: &sp1_core::stark::ShardProof, + vk: &sp1_core::stark::StarkVerifyingKey, + vk_hash: [u32; 8], + committed_value_digest: [u32; 8], + ) -> Result<(), MachineVerificationError> { + // Check that the vk hash matches the vk hash from the input. + if vk.hash_u32() != vk_hash { + return Err(MachineVerificationError::InvalidPublicValues( + "vk hash from syscall does not match vkey from input", + )); + } + // Check that proof is valid. 
+ self.verify_compressed( + &SP1ReduceProof { + proof: proof.clone(), + }, + &SP1VerifyingKey { vk: vk.clone() }, + )?; + // Check that the committed value digest matches the one from syscall + let public_values: &RecursionPublicValues<_> = proof.public_values.as_slice().borrow(); + for (i, word) in public_values.committed_value_digest.iter().enumerate() { + if *word != committed_value_digest[i].into() { + return Err(MachineVerificationError::InvalidPublicValues( + "committed_value_digest does not match", + )); + } + } + Ok(()) + } +} diff --git a/recursion/circuit/Cargo.toml b/recursion/circuit/Cargo.toml index 8843d11859..fc2b0d8441 100644 --- a/recursion/circuit/Cargo.toml +++ b/recursion/circuit/Cargo.toml @@ -12,7 +12,7 @@ p3-matrix = { workspace = true } p3-util = { workspace = true } sp1-recursion-core = { path = "../core" } sp1-core = { path = "../../core" } -itertools = "0.12.1" +itertools = "0.13.0" serde = { version = "1.0.201", features = ["derive"] } sp1-recursion-derive = { path = "../derive" } sp1-recursion-compiler = { path = "../compiler" } @@ -31,3 +31,6 @@ p3-poseidon2 = { workspace = true } zkhash = { git = "https://github.com/HorizenLabs/poseidon2" } rand = "0.8.5" sp1-recursion-gnark-ffi = { path = "../gnark-ffi" } + +[features] +native-gnark = ["sp1-recursion-gnark-ffi/native"] diff --git a/recursion/circuit/src/poseidon2.rs b/recursion/circuit/src/poseidon2.rs index a5a8cc1136..792754014d 100644 --- a/recursion/circuit/src/poseidon2.rs +++ b/recursion/circuit/src/poseidon2.rs @@ -1,5 +1,7 @@ //! An implementation of Poseidon2 over BN254. +use std::array; + use itertools::Itertools; use p3_field::AbstractField; use p3_field::Field; @@ -16,6 +18,8 @@ pub trait Poseidon2CircuitBuilder { fn p2_permute_mut(&mut self, state: [Var; SPONGE_SIZE]); fn p2_hash(&mut self, input: &[Felt]) -> OuterDigestVariable; fn p2_compress(&mut self, input: [OuterDigestVariable; 2]) -> OuterDigestVariable; + fn p2_babybear_permute_mut(&mut self, state: [Felt; 16]); + fn p2_babybear_hash(&mut self, input: &[Felt]) -> [Felt; 8]; } impl Poseidon2CircuitBuilder for Builder { @@ -52,6 +56,24 @@ impl Poseidon2CircuitBuilder for Builder { self.p2_permute_mut(state); [state[0]; DIGEST_SIZE] } + + fn p2_babybear_permute_mut(&mut self, state: [Felt; 16]) { + self.push(DslIr::CircuitPoseidon2PermuteBabyBear(state)); + } + + fn p2_babybear_hash(&mut self, input: &[Felt]) -> [Felt; 8] { + let mut state: [Felt; 16] = array::from_fn(|_| self.eval(C::F::zero())); + + for block_chunk in &input.iter().chunks(8) { + state + .iter_mut() + .zip(block_chunk) + .for_each(|(s, i)| *s = self.eval(*i)); + self.p2_babybear_permute_mut(state); + } + + array::from_fn(|i| state[i]) + } } #[cfg(test)] @@ -60,6 +82,9 @@ pub mod tests { use p3_bn254_fr::Bn254Fr; use p3_field::AbstractField; use p3_symmetric::{CryptographicHasher, Permutation, PseudoCompressionFunction}; + use rand::thread_rng; + use rand::Rng; + use sp1_core::utils::{inner_perm, InnerHash}; use sp1_recursion_compiler::config::OuterConfig; use sp1_recursion_compiler::constraints::ConstraintCompiler; use sp1_recursion_compiler::ir::{Builder, Felt, Var, Witness}; @@ -95,6 +120,25 @@ pub mod tests { PlonkBn254Prover::test::(constraints.clone(), Witness::default()); } + #[test] + fn test_p2_babybear_permute_mut() { + let mut rng = thread_rng(); + let mut builder = Builder::::default(); + let input: [BabyBear; 16] = [rng.gen(); 16]; + let input_vars: [Felt<_>; 16] = input.map(|x| builder.eval(x)); + builder.p2_babybear_permute_mut(input_vars); + + let perm = 
inner_perm(); + let result = perm.permute(input); + for i in 0..16 { + builder.assert_felt_eq(input_vars[i], result[i]); + } + + let mut backend = ConstraintCompiler::::default(); + let constraints = backend.emit(builder.operations); + PlonkBn254Prover::test::(constraints.clone(), Witness::default()); + } + #[test] fn test_p2_hash() { let perm = outer_perm(); @@ -147,4 +191,53 @@ pub mod tests { let constraints = backend.emit(builder.operations); PlonkBn254Prover::test::(constraints.clone(), Witness::default()); } + + #[test] + fn test_p2_babybear_hash() { + let perm = inner_perm(); + let hasher = InnerHash::new(perm.clone()); + + let input: [BabyBear; 26] = [ + BabyBear::from_canonical_u32(0), + BabyBear::from_canonical_u32(1), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + ]; + let output = hasher.hash_iter(input); + println!("{:?}", output); + + let mut builder = Builder::::default(); + let input_felts: [Felt<_>; 26] = input.map(|x| builder.eval(x)); + let result = builder.p2_babybear_hash(input_felts.as_slice()); + + for i in 0..8 { + builder.assert_felt_eq(result[i], output[i]); + } + + let mut backend = ConstraintCompiler::::default(); + let constraints = backend.emit(builder.operations); + PlonkBn254Prover::test::(constraints.clone(), Witness::default()); + } } diff --git a/recursion/circuit/src/stark.rs b/recursion/circuit/src/stark.rs index 48c800aa13..574a801411 100644 --- a/recursion/circuit/src/stark.rs +++ b/recursion/circuit/src/stark.rs @@ -2,6 +2,7 @@ use std::borrow::Borrow; use std::marker::PhantomData; use crate::fri::verify_two_adic_pcs; +use crate::poseidon2::Poseidon2CircuitBuilder; use crate::types::OuterDigestVariable; use crate::utils::{babybear_bytes_to_bn254, babybears_to_bn254, words_to_bytes}; use crate::witness::Witnessable; @@ -20,7 +21,7 @@ use sp1_recursion_compiler::constraints::{Constraint, ConstraintCompiler}; use sp1_recursion_compiler::ir::{Builder, Config, Ext, Felt, Var}; use sp1_recursion_compiler::ir::{Usize, Witness}; use sp1_recursion_compiler::prelude::SymbolicVar; -use sp1_recursion_core::air::RecursionPublicValues; +use sp1_recursion_core::air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH}; use sp1_recursion_core::stark::config::{outer_fri_config, BabyBearPoseidon2Outer}; use sp1_recursion_core::stark::RecursionAirSkinnyDeg9; use sp1_recursion_program::commit::PolynomialSpaceVariable; @@ -270,7 +271,9 @@ pub fn build_wrap_circuit( let element = builder.get(&proof.public_values, i); pv_elements.push(element); } + let pv: &RecursionPublicValues<_> = pv_elements.as_slice().borrow(); + let one_felt: Felt<_> = builder.constant(BabyBear::one()); // Proof must be complete. In the reduce program, this will ensure that the SP1 proof has been // fully accumulated. 
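The `p2_babybear_hash` gadget introduced above absorbs its input eight `Felt`s at a time through the BabyBear Poseidon2 permutation, and the next hunk uses it inside `build_wrap_circuit` to tie the public values to their digest. A host-side sketch of the equivalent check, reusing the `inner_perm`/`InnerHash` helpers already imported by the test above (the function name and return convention are illustrative):

```rust
// Host-side analogue of the digest check added to build_wrap_circuit below:
// hash the first NUM_PV_ELMS_TO_HASH public-value elements and compare the
// result against the digest carried in the public values. inner_perm and
// InnerHash are the helpers used by the p2_babybear_hash test; the function
// name is an assumption for illustration.
use p3_baby_bear::BabyBear;
use p3_symmetric::CryptographicHasher;
use sp1_core::utils::{inner_perm, InnerHash};
use sp1_recursion_core::air::NUM_PV_ELMS_TO_HASH;

fn pv_digest_matches(pv_elements: &[BabyBear], expected_digest: [BabyBear; 8]) -> bool {
    let hasher = InnerHash::new(inner_perm());
    let calculated: [BabyBear; 8] =
        hasher.hash_iter(pv_elements[..NUM_PV_ELMS_TO_HASH].iter().copied());
    calculated == expected_digest
}
```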
@@ -347,6 +350,13 @@ pub fn build_wrap_circuit( } builder.assert_ext_eq(cumulative_sum, zero_ext); + // Verify the public values digest. + let calculated_digest = builder.p2_babybear_hash(&pv_elements[0..NUM_PV_ELMS_TO_HASH]); + let expected_digest = pv.digest; + for (calculated_elm, expected_elm) in calculated_digest.iter().zip(expected_digest.iter()) { + builder.assert_felt_eq(*expected_elm, *calculated_elm); + } + let mut backend = ConstraintCompiler::::default(); backend.emit(builder.operations) } diff --git a/recursion/compiler/Cargo.toml b/recursion/compiler/Cargo.toml index 736013ff09..70aee2868f 100644 --- a/recursion/compiler/Cargo.toml +++ b/recursion/compiler/Cargo.toml @@ -13,7 +13,7 @@ p3-symmetric = { workspace = true } p3-util = { workspace = true } sp1-recursion-core = { path = "../core" } sp1-core = { path = "../../core" } -itertools = "0.12.1" +itertools = "0.13.0" serde = { version = "1.0.201", features = ["derive"] } sp1-recursion-derive = { path = "../derive" } p3-bn254-fr = { workspace = true } diff --git a/recursion/compiler/src/asm/compiler.rs b/recursion/compiler/src/asm/compiler.rs index 5c2c9a22bf..3ec01ebe48 100644 --- a/recursion/compiler/src/asm/compiler.rs +++ b/recursion/compiler/src/asm/compiler.rs @@ -536,6 +536,12 @@ impl + TwoAdicField> AsmCo DslIr::Halt => { self.push(AsmInstruction::Halt, trace); } + DslIr::ExpReverseBitsLen(base, ptr, len) => { + self.push( + AsmInstruction::ExpReverseBitsLen(base.fp(), ptr.fp(), len.fp()), + trace, + ); + } _ => unimplemented!(), } } @@ -594,7 +600,7 @@ impl + TwoAdicField> AsmCo pub fn compile(self) -> RecursionProgram { let code = self.code(); - tracing::info!("recursion program size: {}", code.size()); + tracing::debug!("recursion program size: {}", code.size()); code.machine_code() } diff --git a/recursion/compiler/src/asm/instruction.rs b/recursion/compiler/src/asm/instruction.rs index 0d965e6fee..78befc94d4 100644 --- a/recursion/compiler/src/asm/instruction.rs +++ b/recursion/compiler/src/asm/instruction.rs @@ -167,18 +167,24 @@ pub enum AsmInstruction { /// Hint a vector of blocks. Hint(i32), - // FRIFold(m, input). + /// FRIFold(m, input). FriFold(i32, i32), - // Commit(val, index). + /// Commit(val, index). Commit(i32, i32), - // RegisterPublicValue(val). + /// RegisterPublicValue(val). RegisterPublicValue(i32), LessThan(i32, i32, i32), CycleTracker(String), + + /// ExpReverseBitsLen instruction: (mathematical description) given `x`, `exp`, `len`, bit-reverse the last `len` bits of + /// `exp` and raise `x` to the power of the resulting value. The arguments are a pointer to the + /// addresss at which `x` is located (will be written to with the result), a pointer to the + /// address containing the bits of `exp` stored as a little-endian bit array, and `len`. 
+ ExpReverseBitsLen(i32, i32, i32), } impl> AsmInstruction { @@ -862,6 +868,17 @@ impl> AsmInstruction { true, "".to_string(), ), + AsmInstruction::ExpReverseBitsLen(base, ptr, len) => Instruction::new( + Opcode::ExpReverseBitsLen, + i32_f(base), + i32_f_arr(ptr), + i32_f_arr(len), + F::zero(), + F::zero(), + false, + false, + "".to_string(), + ), } } @@ -1136,6 +1153,13 @@ impl> AsmInstruction { AsmInstruction::CycleTracker(name) => { write!(f, "cycle-tracker {}", name) } + AsmInstruction::ExpReverseBitsLen(base, ptr, len) => { + write!( + f, + "exp_reverse_bits_len ({})fp, ({})fp, ({})fp", + base, ptr, len + ) + } } } } diff --git a/recursion/compiler/src/constraints/mod.rs b/recursion/compiler/src/constraints/mod.rs index c5c67647a9..eb43951358 100644 --- a/recursion/compiler/src/constraints/mod.rs +++ b/recursion/compiler/src/constraints/mod.rs @@ -249,6 +249,10 @@ impl ConstraintCompiler { opcode: ConstraintOpcode::Permute, args: state.iter().map(|x| vec![x.id()]).collect(), }), + DslIr::CircuitPoseidon2PermuteBabyBear(state) => constraints.push(Constraint { + opcode: ConstraintOpcode::PermuteBabyBear, + args: state.iter().map(|x| vec![x.id()]).collect(), + }), DslIr::CircuitSelectV(cond, a, b, out) => { constraints.push(Constraint { opcode: ConstraintOpcode::SelectV, diff --git a/recursion/compiler/src/constraints/opcodes.rs b/recursion/compiler/src/constraints/opcodes.rs index 581b4558da..4911e0f108 100644 --- a/recursion/compiler/src/constraints/opcodes.rs +++ b/recursion/compiler/src/constraints/opcodes.rs @@ -46,4 +46,5 @@ pub enum ConstraintOpcode { CommitVkeyHash, CommitCommitedValuesDigest, CircuitFelts2Ext, + PermuteBabyBear, } diff --git a/recursion/compiler/src/ir/bits.rs b/recursion/compiler/src/ir/bits.rs index f69c8cee1d..396fb92be4 100644 --- a/recursion/compiler/src/ir/bits.rs +++ b/recursion/compiler/src/ir/bits.rs @@ -26,6 +26,15 @@ impl Builder { output } + /// Range checks a variable to a certain number of bits. + pub fn range_check_v(&mut self, num: Var, num_bits: usize) { + let bits = self.num2bits_v(num); + self.range(num_bits, bits.len()).for_each(|i, builder| { + let bit = builder.get(&bits, i); + builder.assert_var_eq(bit, C::N::zero()); + }); + } + /// Converts a variable to bits inside a circuit. pub fn num2bits_v_circuit(&mut self, num: Var, bits: usize) -> Vec> { let mut output = Vec::new(); diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index f7a5cee3e0..f5c2a1b856 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -201,6 +201,8 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should only /// be used when target is a gnark circuit. CircuitPoseidon2Permute([Var; 3]), + /// Permutates an array of BabyBear elements in the circuit. + CircuitPoseidon2PermuteBabyBear([Felt; 16]), // Miscellaneous instructions. /// Decompose hint operation of a usize into an array. (output = num2bits(usize)). @@ -277,4 +279,7 @@ pub enum DslIr { LessThan(Var, Var, Var), /// Tracks the number of cycles used by a block of code annotated by the string input. CycleTracker(String), + + // Reverse bits exponentiation. 
+ ExpReverseBitsLen(Ptr, Var, Var), } diff --git a/recursion/compiler/src/ir/utils.rs b/recursion/compiler/src/ir/utils.rs index bea55c21b4..3c8a6b5657 100644 --- a/recursion/compiler/src/ir/utils.rs +++ b/recursion/compiler/src/ir/utils.rs @@ -130,6 +130,41 @@ impl Builder { result } + /// A version of `exp_reverse_bits_len` that uses the ExpReverseBitsLen precompile. + pub fn exp_reverse_bits_len_fast( + &mut self, + x: Felt, + power_bits: &Array>, + bit_len: impl Into>, + ) -> Felt { + // Instantiate an array of length one and store the value of x. + let mut x_copy_arr: Array> = self.dyn_array(1); + self.set(&mut x_copy_arr, 0, x); + // Get a pointer to the address holding x. + let x_copy_arr_ptr = match x_copy_arr { + Array::Dyn(ptr, _) => ptr, + _ => panic!("Expected a dynamic array"), + }; + + // Materialize the bit length as a Var. + let bit_len_var = bit_len.into().materialize(self); + // Get a pointer to the array of bits in the exponent. + let ptr = match power_bits { + Array::Dyn(ptr, _) => ptr, + _ => panic!("Expected a dynamic array"), + }; + + // Call the DslIR instruction ExpReverseBitsLen, which modifies the memory pointed to by `x_copy_arr_ptr`. + self.push(DslIr::ExpReverseBitsLen( + x_copy_arr_ptr, + ptr.address, + bit_len_var, + )); + + // Return the value stored at the address pointed to by `x_copy_arr_ptr`. + self.get(&x_copy_arr, 0) + } + /// Exponentiates a variable to a list of bits in little endian. pub fn exp_power_of_2_v( &mut self, diff --git a/recursion/core/Cargo.toml b/recursion/core/Cargo.toml index 4a227deb2a..9a6588ec5c 100644 --- a/recursion/core/Cargo.toml +++ b/recursion/core/Cargo.toml @@ -5,6 +5,7 @@ version = "0.1.0" [dependencies] p3-field = { workspace = true } +p3-util = { workspace = true } p3-baby-bear = { workspace = true } p3-air = { workspace = true } p3-matrix = { workspace = true } @@ -16,7 +17,7 @@ sp1-primitives = { path = "../../primitives" } tracing = "0.1.40" sp1-core = { path = "../../core" } hashbrown = "0.14.5" -itertools = "0.12.1" +itertools = "0.13.0" p3-bn254-fr = { workspace = true } p3-merkle-tree = { workspace = true } p3-commit = { workspace = true } diff --git a/recursion/core/src/air/builder.rs b/recursion/core/src/air/builder.rs index 6bcf20d408..ab6e6c1017 100644 --- a/recursion/core/src/air/builder.rs +++ b/recursion/core/src/air/builder.rs @@ -30,6 +30,8 @@ pub trait RecursionMemoryAirBuilder: RecursionInteractionAirBuilder { is_real: impl Into, ) { let is_real: Self::Expr = is_real.into(); + self.assert_bool(is_real.clone()); + let timestamp: Self::Expr = timestamp.into(); let mem_access = memory_access.access(); @@ -66,6 +68,8 @@ pub trait RecursionMemoryAirBuilder: RecursionInteractionAirBuilder { is_real: impl Into, ) { let is_real: Self::Expr = is_real.into(); + self.assert_bool(is_real.clone()); + let timestamp: Self::Expr = timestamp.into(); let mem_access = memory_access.access(); diff --git a/recursion/core/src/air/is_zero.rs b/recursion/core/src/air/is_zero.rs new file mode 100644 index 0000000000..6e92ef5dd1 --- /dev/null +++ b/recursion/core/src/air/is_zero.rs @@ -0,0 +1,83 @@ +//! An operation to check if the input is 0. +//! +//! This is guaranteed to return 1 if and only if the input is 0. +//! +//! The idea is that 1 - input * inverse is exactly the boolean value indicating whether the input +//! is 0. 
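Two of the additions in this part of the change set are worth spelling out. First, `exp_reverse_bits_len_fast` above routes through the new `ExpReverseBitsLen` precompile; its intended semantics, as documented on `AsmInstruction::ExpReverseBitsLen` and asserted in `ExpReverseBitsLenEvent::dummy_from_input` further down, reduce to the following reference computation (a sketch; the generic bound is an assumption):

```rust
// Reference semantics of the ExpReverseBitsLen precompile: bit-reverse the
// low `len` bits of `exp`, then raise `x` to the resulting power.
use p3_field::AbstractField;
use p3_util::reverse_bits_len;

fn exp_reverse_bits_len_reference<F: AbstractField>(x: F, exp: usize, len: usize) -> F {
    x.exp_u64(reverse_bits_len(exp, len) as u64)
}
```

Second, the module comment immediately above captures the trick behind `IsZeroOperation`, whose implementation follows: the prover supplies an alleged inverse, `1 - input * inverse` must equal the claimed result, and a separate `result * input = 0` constraint rules out a dishonest witness for nonzero inputs. A quick numeric check of that identity over BabyBear, mirroring the witness that `populate` computes (a standalone sanity check, not part of the chip):

```rust
// Numeric sanity check of the is-zero identity used by IsZeroOperation.
// is_zero_witness mirrors the (inverse, result) pair that populate() produces.
use p3_baby_bear::BabyBear;
use p3_field::{AbstractField, Field};

fn is_zero_witness(a: BabyBear) -> (BabyBear, BabyBear) {
    if a.is_zero() {
        (BabyBear::zero(), BabyBear::one())
    } else {
        (a.inverse(), BabyBear::zero())
    }
}

fn main() {
    for a in [BabyBear::zero(), BabyBear::from_canonical_u32(7)] {
        let (inverse, result) = is_zero_witness(a);
        // Constraint: result == 1 - inverse * a.
        assert_eq!(result, BabyBear::one() - inverse * a);
        // Constraint: result * a == 0, so result == 1 forces a == 0.
        assert_eq!(result * a, BabyBear::zero());
    }
}
```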
+use p3_air::AirBuilder; +use p3_field::AbstractField; +use p3_field::Field; +use sp1_derive::AlignedBorrow; + +use sp1_core::air::SP1AirBuilder; + +/// A set of columns needed to compute whether the given word is 0. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct IsZeroOperation { + /// The inverse of the input. + pub inverse: T, + + /// Result indicating whether the input is 0. This equals `inverse * input == 0`. + pub result: T, +} + +impl IsZeroOperation { + pub fn populate(&mut self, a: F) -> F { + let (inverse, result) = if a.is_zero() { + (F::zero(), F::one()) + } else { + (a.inverse(), F::zero()) + }; + + self.inverse = inverse; + self.result = result; + + let prod = inverse * a; + debug_assert!(prod == F::one() || prod.is_zero()); + + result + } +} + +impl IsZeroOperation { + pub fn eval( + builder: &mut AB, + a: AB::Expr, + cols: IsZeroOperation, + is_real: AB::Expr, + ) { + // Assert that the `is_real` is a boolean. + builder.assert_bool(is_real.clone()); + // Assert that the result is boolean. + builder.when(is_real.clone()).assert_bool(cols.result); + + // 1. Input == 0 => is_zero = 1 regardless of the inverse. + // 2. Input != 0 + // 2.1. inverse is correctly set => is_zero = 0. + // 2.2. inverse is incorrect + // 2.2.1 inverse is nonzero => is_zero isn't bool, it fails. + // 2.2.2 inverse is 0 => is_zero is 1. But then we would assert that a = 0. And that + // assert fails. + + // If the input is 0, then any product involving it is 0. If it is nonzero and its inverse + // is correctly set, then the product is 1. + + let one = AB::Expr::one(); + let inverse = cols.inverse; + + let is_zero = one.clone() - inverse * a.clone(); + + builder + .when(is_real.clone()) + .assert_eq(is_zero, cols.result); + + builder.when(is_real.clone()).assert_bool(cols.result); + + // If the result is 1, then the input is 0. + builder + .when(is_real.clone()) + .when(cols.result) + .assert_zero(a.clone()); + } +} diff --git a/recursion/core/src/air/mod.rs b/recursion/core/src/air/mod.rs index 9e118c6552..6fb0451170 100644 --- a/recursion/core/src/air/mod.rs +++ b/recursion/core/src/air/mod.rs @@ -2,6 +2,7 @@ mod block; mod builder; mod extension; mod is_ext_zero; +mod is_zero; mod multi_builder; mod public_values; @@ -9,5 +10,6 @@ pub use block::*; pub use builder::*; pub use extension::*; pub use is_ext_zero::*; +pub use is_zero::*; pub use multi_builder::*; pub use public_values::*; diff --git a/recursion/core/src/cpu/air/branch.rs b/recursion/core/src/cpu/air/branch.rs index 91bebfa65e..105dd771da 100644 --- a/recursion/core/src/cpu/air/branch.rs +++ b/recursion/core/src/cpu/air/branch.rs @@ -3,7 +3,9 @@ use p3_field::{AbstractField, Field}; use sp1_core::air::{BinomialExtension, ExtensionAirBuilder}; use crate::{ - air::{BinomialExtensionUtils, IsExtZeroOperation, SP1RecursionAirBuilder}, + air::{ + BinomialExtensionUtils, Block, BlockBuilder, IsExtZeroOperation, SP1RecursionAirBuilder, + }, cpu::{CpuChip, CpuCols}, memory::MemoryCols, }; @@ -22,18 +24,24 @@ impl CpuChip { let is_branch_instruction = self.is_branch_instruction::(local); let one = AB::Expr::one(); - // If the instruction is a BNEINC, verify that the a value is incremented by one. - builder - .when(local.is_real) - .when(local.selectors.is_bneinc) - .assert_eq(local.a.value()[0], local.a.prev_value()[0] + one.clone()); - // Convert operand values from Block to BinomialExtension. Note that it gets the // previous value of the `a` and `b` operands, since BNENIC will modify `a`. 
+ let a_prev_ext: BinomialExtension = + BinomialExtensionUtils::from_block(local.a.prev_value().map(|x| x.into())); let a_ext: BinomialExtension = BinomialExtensionUtils::from_block(local.a.value().map(|x| x.into())); let b_ext: BinomialExtension = BinomialExtensionUtils::from_block(local.b.value().map(|x| x.into())); + let one_ext: BinomialExtension = + BinomialExtensionUtils::from_block(Block::from(one.clone())); + + let expected_a_ext = a_prev_ext + one_ext; + + // If the instruction is a BNEINC, verify that the a value is incremented by one. + builder + .when(local.is_real) + .when(local.selectors.is_bneinc) + .assert_block_eq(a_ext.as_block(), expected_a_ext.as_block()); let comparison_diff = a_ext - b_ext; diff --git a/recursion/core/src/cpu/air/jump.rs b/recursion/core/src/cpu/air/jump.rs index bf86a70cce..dd5e9b8bba 100644 --- a/recursion/core/src/cpu/air/jump.rs +++ b/recursion/core/src/cpu/air/jump.rs @@ -2,7 +2,7 @@ use p3_air::AirBuilder; use p3_field::{AbstractField, Field}; use crate::{ - air::SP1RecursionAirBuilder, + air::{Block, BlockBuilder, SP1RecursionAirBuilder}, cpu::{CpuChip, CpuCols}, memory::MemoryCols, runtime::STACK_SIZE, @@ -21,19 +21,29 @@ impl CpuChip { ) where AB: SP1RecursionAirBuilder, { + let is_jump_instr = self.is_jump_instruction::(local); + // Verify the next row's fp. builder .when_first_row() .assert_eq(local.fp, F::from_canonical_usize(STACK_SIZE)); - let not_jump_instruction = AB::Expr::one() - self.is_jump_instruction::(local); + let not_jump_instruction = AB::Expr::one() - is_jump_instr.clone(); let expected_next_fp = local.selectors.is_jal * (local.fp + local.c.value()[0]) - + local.selectors.is_jalr * local.a.value()[0] + + local.selectors.is_jalr * local.c.value()[0] + not_jump_instruction * local.fp; builder .when_transition() .when(next.is_real) .assert_eq(next.fp, expected_next_fp); + // Verify the a operand values. + let expected_a_val = local.selectors.is_jal * local.pc + + local.selectors.is_jalr * (local.pc + AB::Expr::one()); + let expected_a_val_block = Block::from(expected_a_val); + builder + .when(is_jump_instr) + .assert_block_eq(*local.a.value(), expected_a_val_block); + // Add to the `next_pc` expression. *next_pc += local.selectors.is_jal * (local.pc + local.b.value()[0]); *next_pc += local.selectors.is_jalr * local.b.value()[0]; diff --git a/recursion/core/src/cpu/air/memory.rs b/recursion/core/src/cpu/air/memory.rs index c0a3a2b639..d1b024130f 100644 --- a/recursion/core/src/cpu/air/memory.rs +++ b/recursion/core/src/cpu/air/memory.rs @@ -30,7 +30,7 @@ impl CpuChip { local.clk + AB::F::from_canonical_u32(MemoryAccessPosition::Memory as u32), memory_cols.memory_addr, &memory_cols.memory, - is_memory_instr, + is_memory_instr.clone(), ); // Constraints on the memory column depending on load or store. @@ -41,7 +41,7 @@ impl CpuChip { ); // When there is a store, we ensure that we are writing the value of the a operand to the memory. builder - .when(local.selectors.is_store) + .when(is_memory_instr) .assert_block_eq(*local.a.value(), *memory_cols.memory.value()); } } diff --git a/recursion/core/src/cpu/air/mod.rs b/recursion/core/src/cpu/air/mod.rs index 28343367c8..b676d0bb25 100644 --- a/recursion/core/src/cpu/air/mod.rs +++ b/recursion/core/src/cpu/air/mod.rs @@ -72,7 +72,10 @@ where } // Constrain the syscalls. 
- let send_syscall = local.selectors.is_poseidon + local.selectors.is_fri_fold; + let send_syscall = local.selectors.is_poseidon + + local.selectors.is_fri_fold + + local.selectors.is_exp_reverse_bits_len; + let operands = [ local.clk.into(), local.a.value()[0].into(), @@ -118,7 +121,7 @@ impl CpuChip { builder .when_transition() .when(next.is_real) - .when_not(local.selectors.is_fri_fold) + .when_not(local.selectors.is_fri_fold + local.selectors.is_exp_reverse_bits_len) .assert_eq(local.clk.into() + AB::F::from_canonical_u32(4), next.clk); builder @@ -126,6 +129,12 @@ impl CpuChip { .when(next.is_real) .when(local.selectors.is_fri_fold) .assert_eq(local.clk.into() + local.a.value()[0], next.clk); + + builder + .when_transition() + .when(next.is_real) + .when(local.selectors.is_exp_reverse_bits_len) + .assert_eq(local.clk.into() + local.c.value()[0], next.clk); } /// Eval the is_real flag. diff --git a/recursion/core/src/cpu/columns/opcode.rs b/recursion/core/src/cpu/columns/opcode.rs index 6fb0026948..1eeb094899 100644 --- a/recursion/core/src/cpu/columns/opcode.rs +++ b/recursion/core/src/cpu/columns/opcode.rs @@ -44,7 +44,7 @@ pub struct OpcodeSelectorCols { pub is_fri_fold: T, pub is_commit: T, pub is_ext_to_felt: T, - + pub is_exp_reverse_bits_len: T, pub is_heap_expand: T, } @@ -70,6 +70,7 @@ impl OpcodeSelectorCols { Opcode::TRAP => self.is_trap = F::one(), Opcode::HALT => self.is_halt = F::one(), Opcode::FRIFold => self.is_fri_fold = F::one(), + Opcode::ExpReverseBitsLen => self.is_exp_reverse_bits_len = F::one(), Opcode::Poseidon2Compress => self.is_poseidon = F::one(), Opcode::Commit => self.is_commit = F::one(), Opcode::HintExt2Felt => self.is_ext_to_felt = F::one(), @@ -125,6 +126,7 @@ impl IntoIterator for &OpcodeSelectorCols { self.is_fri_fold, self.is_commit, self.is_ext_to_felt, + self.is_exp_reverse_bits_len, self.is_heap_expand, ] .into_iter() diff --git a/recursion/core/src/exp_reverse_bits/mod.rs b/recursion/core/src/exp_reverse_bits/mod.rs new file mode 100644 index 0000000000..790d745bcc --- /dev/null +++ b/recursion/core/src/exp_reverse_bits/mod.rs @@ -0,0 +1,518 @@ +#![allow(clippy::needless_range_loop)] + +use crate::air::{Block, IsZeroOperation, RecursionMemoryAirBuilder}; +use crate::memory::{MemoryReadSingleCols, MemoryReadWriteSingleCols}; +use crate::runtime::Opcode; +use core::borrow::Borrow; +use itertools::Itertools; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::AbstractField; +use p3_field::PrimeField32; +use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::Matrix; +use p3_util::reverse_bits_len; +use sp1_core::air::{BaseAirBuilder, ExtensionAirBuilder, MachineAir, SP1AirBuilder}; +use sp1_core::utils::pad_rows_fixed; +use sp1_derive::AlignedBorrow; +use std::borrow::BorrowMut; +use tracing::instrument; + +use crate::air::SP1RecursionAirBuilder; +use crate::memory::MemoryRecord; +use crate::runtime::{ExecutionRecord, RecursionProgram}; + +pub const NUM_EXP_REVERSE_BITS_LEN_COLS: usize = core::mem::size_of::>(); + +#[derive(Default)] +pub struct ExpReverseBitsLenChip { + pub fixed_log2_rows: Option, + pub pad: bool, +} + +#[derive(Debug, Clone)] +pub struct ExpReverseBitsLenEvent { + /// The clk cycle for the event. + pub clk: F, + + /// Memory records to keep track of the value stored in the x parameter, and the current bit + /// of the exponent being scanned. + pub x: MemoryRecord, + pub current_bit: MemoryRecord, + + /// The length parameter of the function. 
+ pub len: F, + + /// The previous accumulator value, needed to compute the current accumulator value. + pub prev_accum: F, + + /// The current accumulator value. + pub accum: F, + + /// A pointer to the memory address storing the exponent. + pub ptr: F, + + /// A pointer to the memory address storing the base. + pub base_ptr: F, + + /// Which step (in the range 0..len) of the computation we are in. + pub iteration_num: F, +} + +impl ExpReverseBitsLenEvent { + /// A way to construct a list of dummy events from input x and clk, used for testing. + pub fn dummy_from_input(x: F, exponent: u32, len: F, timestamp: F) -> Vec { + let mut events = Vec::new(); + let mut new_len = len; + let mut new_exponent = exponent; + let mut accum = F::one(); + + for i in 0..len.as_canonical_u32() { + let current_bit = new_exponent % 2; + let prev_accum = accum; + accum = prev_accum * prev_accum * if current_bit == 0 { F::one() } else { x }; + events.push(Self { + clk: timestamp + F::from_canonical_u32(i), + x: MemoryRecord::new_write( + F::one(), + Block::from([ + if i == len.as_canonical_u32() - 1 { + accum + } else { + x + }, + F::zero(), + F::zero(), + F::zero(), + ]), + timestamp + F::from_canonical_u32(i), + Block::from([x, F::zero(), F::zero(), F::zero()]), + timestamp + F::from_canonical_u32(i) - F::one(), + ), + current_bit: MemoryRecord::new_read( + F::zero(), + Block::from([ + F::from_canonical_u32(current_bit), + F::zero(), + F::zero(), + F::zero(), + ]), + timestamp + F::from_canonical_u32(i), + timestamp + F::from_canonical_u32(i) - F::one(), + ), + len: new_len, + prev_accum, + accum, + ptr: F::zero(), + base_ptr: F::one(), + iteration_num: F::from_canonical_u32(i), + }); + new_exponent /= 2; + new_len -= F::one(); + } + assert_eq!( + accum, + x.exp_u64(reverse_bits_len(exponent as usize, len.as_canonical_u32() as usize) as u64) + ); + events + } +} + +#[derive(AlignedBorrow, Debug, Clone, Copy)] +#[repr(C)] +pub struct ExpReverseBitsLenCols { + pub clk: T, + + /// The base of the exponentiation. + pub x: MemoryReadWriteSingleCols, + + /// The length parameter of the exponentiation. This is decremented by 1 every iteration. + pub len: T, + + /// The current bit of the exponent. This is read from memory. + pub current_bit: MemoryReadSingleCols, + + /// The previous accumulator squared. + pub prev_accum_squared: T, + + /// The accumulator of the current iteration. + pub accum: T, + + /// A flag column to check whether the current row represents the last iteration of the computation. + pub is_last: IsZeroOperation, + + /// A flag column to check whether the current row represents the first iteration of the computation. + pub is_first: IsZeroOperation, + + /// A column to count up from 0 to the length of the exponent. + pub iteration_num: T, + + /// A column which equals x if `current_bit` is on, and 1 otherwise. + pub multiplier: T, + + /// The memory address storing the exponent. + pub ptr: T, + + /// The memory address storing the base. + pub base_ptr: T, + + /// A flag column to check whether the base_ptr memory is accessed. Is equal to `is_first` OR + /// `is_last`. 
+ pub x_mem_access_flag: T, + + pub is_real: T, +} + +impl BaseAir for ExpReverseBitsLenChip { + fn width(&self) -> usize { + NUM_EXP_REVERSE_BITS_LEN_COLS + } +} + +impl MachineAir for ExpReverseBitsLenChip { + type Record = ExecutionRecord; + + type Program = RecursionProgram; + + fn name(&self) -> String { + "ExpReverseBitsLen".to_string() + } + + fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) { + // This is a no-op. + } + + #[instrument(name = "generate exp reverse bits len trace", level = "debug", skip_all, fields(rows = input.exp_reverse_bits_len_events.len()))] + fn generate_trace( + &self, + input: &ExecutionRecord, + _: &mut ExecutionRecord, + ) -> RowMajorMatrix { + let mut rows = input + .exp_reverse_bits_len_events + .iter() + .map(|event| { + let mut row = [F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS]; + + let cols: &mut ExpReverseBitsLenCols = row.as_mut_slice().borrow_mut(); + + cols.clk = event.clk; + + cols.x.populate(&event.x); + cols.current_bit.populate(&event.current_bit); + cols.len = event.len; + cols.accum = event.accum; + cols.prev_accum_squared = event.prev_accum * event.prev_accum; + cols.is_last.populate(F::one() - event.len); + cols.is_first.populate(event.iteration_num); + cols.is_real = F::one(); + cols.iteration_num = event.iteration_num; + cols.multiplier = if event.current_bit.value + == Block([F::one(), F::zero(), F::zero(), F::zero()]) + { + // The event may change the value stored in the x memory access, and we need to + // use the previous value. + event.x.prev_value[0] + } else { + F::one() + }; + cols.ptr = event.ptr; + cols.base_ptr = event.base_ptr; + cols.x_mem_access_flag = + F::from_bool(cols.len == F::one() || cols.iteration_num == F::zero()); + + row + }) + .collect_vec(); + + // Pad the trace to a power of two. + if self.pad { + pad_rows_fixed( + &mut rows, + || [F::zero(); NUM_EXP_REVERSE_BITS_LEN_COLS], + self.fixed_log2_rows, + ); + } + + // Convert the trace to a row major matrix. + let trace = RowMajorMatrix::new( + rows.into_iter().flatten().collect(), + NUM_EXP_REVERSE_BITS_LEN_COLS, + ); + + #[cfg(debug_assertions)] + println!( + "exp reverse bits len trace dims is width: {:?}, height: {:?}", + trace.width(), + trace.height() + ); + + trace + } + + fn included(&self, record: &Self::Record) -> bool { + !record.exp_reverse_bits_len_events.is_empty() + } +} + +impl ExpReverseBitsLenChip { + pub fn eval_exp_reverse_bits_len< + AB: BaseAirBuilder + ExtensionAirBuilder + RecursionMemoryAirBuilder + SP1AirBuilder, + >( + &self, + builder: &mut AB, + local: &ExpReverseBitsLenCols, + next: &ExpReverseBitsLenCols, + memory_access: AB::Var, + ) { + // Dummy constraints to normalize to DEGREE when DEGREE > 3. + if DEGREE > 3 { + let lhs = (0..DEGREE) + .map(|_| local.is_real.into()) + .product::(); + let rhs = (0..DEGREE) + .map(|_| local.is_real.into()) + .product::(); + builder.assert_eq(lhs, rhs); + } + + // Constraint that the operands are sent from the CPU table. + let operands = [ + local.clk.into(), + local.base_ptr.into(), + local.ptr.into(), + local.len.into(), + ]; + builder.receive_table( + Opcode::ExpReverseBitsLen.as_field::(), + &operands, + local.is_first.result, + ); + + IsZeroOperation::::eval( + builder, + AB::Expr::one() - local.len, + local.is_last, + local.is_real.into(), + ); + // Assert that the boolean columns are boolean. + builder.assert_bool(local.is_real); + + let current_bit_val = local.current_bit.access.value; + + // Probably redundant, but we assert here that the current bit value is boolean. 
+ builder.assert_bool(current_bit_val); + + // Assert that `is_first` is on for the first row. + builder.when_first_row().assert_one(local.is_first.result); + + // Assert that the next row after a row for which `is_last` is on has `is_first` on. + builder + .when_transition() + .when(next.is_real * local.is_last.result) + .assert_one(next.is_first.result); + + // The accumulator needs to start with the multiplier for every `is_first` row. + builder + .when(local.is_first.result) + .assert_eq(local.accum, local.multiplier); + + // Assert that the last real row has `is_last` on. + builder + .when(local.is_real * (AB::Expr::one() - next.is_real)) + .assert_one(local.is_last.result); + + // `multiplier` is x if the current bit is 1, and 1 if the current bit is 0. + builder + .when(current_bit_val) + .assert_eq(local.multiplier, local.x.prev_value); + builder + .when(local.is_real) + .when_not(current_bit_val) + .assert_eq(local.multiplier, AB::Expr::one()); + + // To get `next.accum`, we multiply `local.prev_accum_squared` by `local.multiplier` when not + // `is_last`. + builder + .when_transition() + .when_not(local.is_last.result) + .assert_eq(local.accum, local.prev_accum_squared * local.multiplier); + + // Constrain the accum_squared column. + builder + .when_transition() + .when_not(local.is_last.result) + .assert_eq(next.prev_accum_squared, local.accum * local.accum); + + // Constrain the memory address `base_ptr` to be the same as the next, as long as not `is_last`. + builder + .when_transition() + .when_not(local.is_last.result) + .assert_eq(local.base_ptr, next.base_ptr); + + // The `len` counter must decrement when not `is_last`. + builder + .when_transition() + .when(local.is_real) + .when_not(local.is_last.result) + .assert_eq(local.len, next.len + AB::Expr::one()); + + // The `iteration_num` counter must increment when not `is_last`. + builder + .when_transition() + .when(local.is_real) + .when_not(local.is_last.result) + .assert_eq(local.iteration_num + AB::Expr::one(), next.iteration_num); + + // The `iteration_num` counter must be 0 iff `is_first` is on. + builder + .when(local.is_first.result) + .assert_eq(local.iteration_num, AB::Expr::zero()); + + // Access the memory for current_bit. + builder.recursion_eval_memory_access_single( + local.clk, + local.ptr, + &local.current_bit, + memory_access, + ); + + // Constrain that the x_mem_access_flag is true when `is_first` or `is_last`. + builder.when(local.is_real).assert_eq( + local.x_mem_access_flag, + local.is_first.result + local.is_last.result + - local.is_first.result * local.is_last.result, + ); + + // Access the memory for x. + // This only needs to be done for the first and last iterations. + builder.recursion_eval_memory_access_single( + local.clk, + local.base_ptr, + &local.x, + local.x_mem_access_flag, + ); + + // The `base_ptr` column stays the same when not `is_last`. + builder + .when_transition() + .when(next.is_real) + .when_not(local.is_last.result) + .assert_eq(next.base_ptr, local.base_ptr); + + // Ensure sequential `clk` values. + builder + .when_transition() + .when_not(local.is_last.result) + .when(next.is_real) + .assert_eq(local.clk + AB::Expr::one(), next.clk); + + // Ensure that the value at the x memory access is unchanged when not `is_last`. + builder + .when_not(local.is_last.result) + .assert_eq(local.x.access.value, local.x.prev_value); + + // Ensure that the value at the x memory access is `accum` when `is_last`. 
+ builder + .when(local.is_last.result) + .assert_eq(local.accum, local.x.access.value); + } + + pub const fn do_exp_bit_memory_access(local: &ExpReverseBitsLenCols) -> T { + local.is_real + } +} + +impl Air for ExpReverseBitsLenChip +where + AB: SP1RecursionAirBuilder, +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local: &ExpReverseBitsLenCols = (*local).borrow(); + let next: &ExpReverseBitsLenCols = (*next).borrow(); + self.eval_exp_reverse_bits_len::( + builder, + local, + next, + Self::do_exp_bit_memory_access::(local), + ); + } +} + +#[cfg(test)] +mod tests { + use itertools::Itertools; + use std::time::Instant; + + use p3_baby_bear::BabyBear; + use p3_baby_bear::DiffusionMatrixBabyBear; + use p3_field::AbstractField; + use p3_matrix::{dense::RowMajorMatrix, Matrix}; + use p3_poseidon2::Poseidon2; + use p3_poseidon2::Poseidon2ExternalMatrixGeneral; + use sp1_core::stark::StarkGenericConfig; + use sp1_core::{ + air::MachineAir, + utils::{uni_stark_prove, uni_stark_verify, BabyBearPoseidon2}, + }; + + use crate::exp_reverse_bits::ExpReverseBitsLenChip; + use crate::exp_reverse_bits::ExpReverseBitsLenEvent; + use crate::runtime::ExecutionRecord; + + #[test] + fn prove_babybear() { + let config = BabyBearPoseidon2::compressed(); + let mut challenger = config.challenger(); + + let chip = ExpReverseBitsLenChip::<5> { + pad: true, + fixed_log2_rows: None, + }; + + let test_xs = (1..16).map(BabyBear::from_canonical_u32).collect_vec(); + + let test_exponents = (1..16).collect_vec(); + + let mut input_exec = ExecutionRecord::::default(); + for (x, exponent) in test_xs.into_iter().zip_eq(test_exponents) { + let mut events = ExpReverseBitsLenEvent::dummy_from_input( + x, + exponent, + BabyBear::from_canonical_u32(exponent.ilog2() + 1), + x, + ); + input_exec.exp_reverse_bits_len_events.append(&mut events); + } + println!( + "input exec: {:?}", + input_exec.exp_reverse_bits_len_events.len() + ); + let trace: RowMajorMatrix = + chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); + println!( + "trace dims is width: {:?}, height: {:?}", + trace.width(), + trace.height() + ); + + let start = Instant::now(); + let proof = uni_stark_prove(&config, &chip, &mut challenger, trace); + let duration = start.elapsed().as_secs_f64(); + println!("proof duration = {:?}", duration); + + let mut challenger: p3_challenger::DuplexChallenger< + BabyBear, + Poseidon2, + 16, + 8, + > = config.challenger(); + let start = Instant::now(); + uni_stark_verify(&config, &chip, &mut challenger, &proof) + .expect("expected proof to be valid"); + + let duration = start.elapsed().as_secs_f64(); + println!("verify duration = {:?}", duration); + } +} diff --git a/recursion/core/src/fri_fold/mod.rs b/recursion/core/src/fri_fold/mod.rs index d59668d1e9..7d68c81216 100644 --- a/recursion/core/src/fri_fold/mod.rs +++ b/recursion/core/src/fri_fold/mod.rs @@ -25,6 +25,7 @@ pub const NUM_FRI_FOLD_COLS: usize = core::mem::size_of::>(); #[derive(Default)] pub struct FriFoldChip { pub fixed_log2_rows: Option, + pub pad: bool, } #[derive(Debug, Clone)] @@ -143,11 +144,13 @@ impl MachineAir for FriFoldChip .collect_vec(); // Pad the trace to a power of two. - pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_FRI_FOLD_COLS], - self.fixed_log2_rows, - ); + if self.pad { + pad_rows_fixed( + &mut rows, + || [F::zero(); NUM_FRI_FOLD_COLS], + self.fixed_log2_rows, + ); + } // Convert the trace to a row major matrix. 
let trace = RowMajorMatrix::new(rows.into_iter().flatten().collect(), NUM_FRI_FOLD_COLS); @@ -211,7 +214,7 @@ impl FriFoldChip { .when(next.is_real) .assert_zero(next.m); - // Ensure that all rows for a FRI FOLD invocation have the same input_ptr, clk, and sequential m values. + // Ensure that all rows for a FRI FOLD invocation have the same input_ptr and sequential clk and m values. builder .when_transition() .when_not(local.is_last_iteration) diff --git a/recursion/core/src/lib.rs b/recursion/core/src/lib.rs index db37f40f6c..785179fa77 100644 --- a/recursion/core/src/lib.rs +++ b/recursion/core/src/lib.rs @@ -1,5 +1,6 @@ pub mod air; pub mod cpu; +pub mod exp_reverse_bits; pub mod fri_fold; pub mod memory; pub mod multi; diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index 23173f7de9..0b93f52aa2 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -66,8 +66,14 @@ impl MachineAir for MultiChip { input: &ExecutionRecord, output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let fri_fold_chip = FriFoldChip::<3>::default(); - let poseidon2 = Poseidon2Chip::default(); + let fri_fold_chip = FriFoldChip::<3> { + fixed_log2_rows: None, + pad: false, + }; + let poseidon2 = Poseidon2Chip { + fixed_log2_rows: None, + pad: false, + }; let fri_fold_trace = fri_fold_chip.generate_trace(input, output); let mut poseidon2_trace = poseidon2.generate_trace(input, output); @@ -145,14 +151,12 @@ where // all the fri fold rows are first, then the posiedon2 rows, and finally any padded (non-real) rows. // First verify that all real rows are contiguous. - builder.when_first_row().assert_one(local_is_real.clone()); builder .when_transition() .when_not(local_is_real.clone()) .assert_zero(next_is_real.clone()); // Next, verify that all fri fold rows are before the poseidon2 rows within the real rows section. 
- builder.when_first_row().assert_one(local.is_fri_fold); builder .when_transition() .when(next_is_real) @@ -190,8 +194,7 @@ where local.poseidon2_receive_table, ); sub_builder.assert_eq( - local.is_poseidon2 - * Poseidon2Chip::do_memory_access::(poseidon2_columns), + local.is_poseidon2 * Poseidon2Chip::do_memory_access::(poseidon2_columns), local.poseidon2_memory_access, ); @@ -201,7 +204,7 @@ where local.poseidon2(), next.poseidon2(), local.poseidon2_receive_table, - local.poseidon2_memory_access.into(), + local.poseidon2_memory_access, ); } } diff --git a/recursion/core/src/poseidon2/columns.rs b/recursion/core/src/poseidon2/columns.rs index 12fa730477..fa12a655f2 100644 --- a/recursion/core/src/poseidon2/columns.rs +++ b/recursion/core/src/poseidon2/columns.rs @@ -11,7 +11,10 @@ pub struct Poseidon2Cols { pub left_input: T, pub right_input: T, pub rounds: [T; 24], // 1 round for memory input; 1 round for initialize; 8 rounds for external; 13 rounds for internal; 1 round for memory output + pub do_receive: T, + pub do_memory: T, pub round_specific_cols: RoundSpecificCols, + pub is_real: T, } #[derive(AlignedBorrow, Clone, Copy)] @@ -45,6 +48,7 @@ impl RoundSpecificCols { pub struct ComputationCols { pub input: [T; WIDTH], pub add_rc: [T; WIDTH], + pub sbox_deg_3: [T; WIDTH], pub sbox_deg_7: [T; WIDTH], pub output: [T; WIDTH], } diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs index c871bd873d..21f56edb0b 100644 --- a/recursion/core/src/poseidon2/external.rs +++ b/recursion/core/src/poseidon2/external.rs @@ -6,7 +6,6 @@ use p3_field::AbstractField; use p3_matrix::Matrix; use sp1_core::air::{BaseAirBuilder, ExtensionAirBuilder, SP1AirBuilder}; use sp1_primitives::RC_16_30_U32; -use std::ops::Add; use crate::air::{RecursionInteractionAirBuilder, RecursionMemoryAirBuilder}; use crate::memory::MemoryCols; @@ -25,6 +24,7 @@ pub const WIDTH: usize = 16; #[derive(Default)] pub struct Poseidon2Chip { pub fixed_log2_rows: Option, + pub pad: bool, } impl BaseAir for Poseidon2Chip { @@ -40,7 +40,7 @@ impl Poseidon2Chip { local: &Poseidon2Cols, next: &Poseidon2Cols, receive_table: AB::Var, - memory_access: AB::Expr, + memory_access: AB::Var, ) { const NUM_ROUNDS_F: usize = 8; const NUM_ROUNDS_P: usize = 13; @@ -66,6 +66,10 @@ impl Poseidon2Chip { .sum::(); let is_memory_write = local.rounds[local.rounds.len() - 1]; + self.eval_control_flow_and_inputs(builder, local, next); + + self.eval_syscall(builder, local, receive_table); + self.eval_mem( builder, local, @@ -84,16 +88,71 @@ impl Poseidon2Chip { is_internal_layer.clone(), NUM_ROUNDS_F + NUM_ROUNDS_P + 1, ); + } - self.eval_syscall(builder, local, receive_table); - - // Range check all flags. - for i in 0..local.rounds.len() { + fn eval_control_flow_and_inputs( + &self, + builder: &mut AB, + local: &Poseidon2Cols, + next: &Poseidon2Cols, + ) { + let num_total_rounds = local.rounds.len(); + for i in 0..num_total_rounds { + // Verify that the round flags are correct. builder.assert_bool(local.rounds[i]); + + // Assert that the next round is correct. 
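+ // (The round flags rotate one step per row: if `rounds[i]` is set on this row then
+ // `rounds[(i + 1) % 24]` must be set on the next, so a permutation occupies 24 consecutive rows
+ // and the wrap-around lets a new invocation begin right after the memory-output round.)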
+ builder + .when_transition() + .assert_eq(local.rounds[i], next.rounds[(i + 1) % num_total_rounds]); + + if i != num_total_rounds - 1 { + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.clk, next.clk); + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.dst_input, next.dst_input); + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.left_input, next.left_input); + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.right_input, next.right_input); + } } - builder.assert_bool( - is_memory_read + is_initial + is_external_layer + is_internal_layer + is_memory_write, + + // Ensure that at most one of the round flags is set. + let round_acc = local + .rounds + .iter() + .fold(AB::Expr::zero(), |acc, round_flag| acc + *round_flag); + builder.assert_bool(round_acc); + + // Verify the do_memory flag. + builder.assert_eq( + local.do_memory, + local.is_real * (local.rounds[0] + local.rounds[23]), ); + + // Verify the do_receive flag. + builder.assert_eq(local.do_receive, local.is_real * local.rounds[0]); + + // Verify the first row starts at round 0. + builder.when_first_row().assert_one(local.rounds[0]); + // The round count is not a power of 2, so the last row should not be real. + builder.when_last_row().assert_zero(local.is_real); + + // Verify that all is_real flags within a round are equal. + let is_last_round = local.rounds[23]; + builder + .when_transition() + .when_not(is_last_round) + .assert_eq(local.is_real, next.is_real); } fn eval_mem( @@ -103,20 +162,23 @@ impl Poseidon2Chip { next: &Poseidon2Cols, is_memory_read: AB::Var, is_memory_write: AB::Var, - memory_access: AB::Expr, + memory_access: AB::Var, ) { let memory_access_cols = local.round_specific_cols.memory_access(); builder + .when(local.is_real) .when(is_memory_read) .assert_eq(local.left_input, memory_access_cols.addr_first_half); builder + .when(local.is_real) .when(is_memory_read) .assert_eq(local.right_input, memory_access_cols.addr_second_half); builder + .when(local.is_real) .when(is_memory_write) .assert_eq(local.dst_input, memory_access_cols.addr_first_half); - builder.when(is_memory_write).assert_eq( + builder.when(local.is_real).when(is_memory_write).assert_eq( local.dst_input + AB::F::from_canonical_usize(WIDTH / 2), memory_access_cols.addr_second_half, ); @@ -131,7 +193,11 @@ impl Poseidon2Chip { local.clk + AB::Expr::one() * is_memory_write, addr, &memory_access_cols.mem_access[i], - memory_access.clone(), + memory_access, + ); + builder.when(local.is_real).when(is_memory_read).assert_eq( + *memory_access_cols.mem_access[i].value(), + *memory_access_cols.mem_access[i].prev_value(), ); } @@ -139,10 +205,14 @@ impl Poseidon2Chip { // computation round. 
let next_computation_col = next.round_specific_cols.computation(); for i in 0..WIDTH { - builder.when_transition().when(is_memory_read).assert_eq( - *memory_access_cols.mem_access[i].value(), - next_computation_col.input[i], - ); + builder + .when_transition() + .when(local.is_real) + .when(is_memory_read) + .assert_eq( + *memory_access_cols.mem_access[i].value(), + next_computation_col.input[i], + ); } } @@ -184,6 +254,7 @@ impl Poseidon2Chip { } } builder + .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(result, computation_cols.add_rc[i]); } @@ -196,8 +267,15 @@ impl Poseidon2Chip { let sbox_deg_3 = computation_cols.add_rc[i] * computation_cols.add_rc[i] * computation_cols.add_rc[i]; - let sbox_deg_7 = sbox_deg_3.clone() * sbox_deg_3.clone() * computation_cols.add_rc[i]; builder + .when(local.is_real) + .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) + .assert_eq(computation_cols.sbox_deg_3[i], sbox_deg_3); + let sbox_deg_7 = computation_cols.sbox_deg_3[i] + * computation_cols.sbox_deg_3[i] + * computation_cols.add_rc[i]; + builder + .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(sbox_deg_7, computation_cols.sbox_deg_7[i]); } @@ -253,6 +331,7 @@ impl Poseidon2Chip { for i in 0..WIDTH { state[i] += sums[i % 4].clone(); builder + .when(local.is_real) .when(is_external_layer.clone() + is_initial.clone()) .assert_eq(state[i].clone(), computation_cols.output[i]); } @@ -264,6 +343,7 @@ impl Poseidon2Chip { let mut state: [AB::Expr; WIDTH] = sbox_result.clone(); internal_linear_layer(&mut state); builder + .when(local.is_real) .when(is_internal_layer.clone()) .assert_all_eq(state.clone(), computation_cols.output); } @@ -281,6 +361,7 @@ impl Poseidon2Chip { builder .when_transition() + .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(computation_cols.output[i], next_round_value); } @@ -307,13 +388,11 @@ impl Poseidon2Chip { } pub const fn do_receive_table(local: &Poseidon2Cols) -> T { - local.rounds[0] + local.do_receive } - pub fn do_memory_access, Output>( - local: &Poseidon2Cols, - ) -> Output { - local.rounds[0] + local.rounds[23] + pub fn do_memory_access(local: &Poseidon2Cols) -> T { + local.do_memory } } @@ -333,7 +412,7 @@ where local, next, Self::do_receive_table::(local), - Self::do_memory_access::(local), + Self::do_memory_access::(local), ); } } @@ -371,6 +450,7 @@ mod tests { fn generate_trace() { let chip = Poseidon2Chip { fixed_log2_rows: None, + pad: true, }; let rng = &mut rand::thread_rng(); @@ -420,6 +500,7 @@ mod tests { let chip = Poseidon2Chip { fixed_log2_rows: None, + pad: true, }; let trace: RowMajorMatrix = chip.generate_trace(&input_exec, &mut ExecutionRecord::::default()); diff --git a/recursion/core/src/poseidon2/trace.rs b/recursion/core/src/poseidon2/trace.rs index cc6a41d94f..567c09fc7d 100644 --- a/recursion/core/src/poseidon2/trace.rs +++ b/recursion/core/src/poseidon2/trace.rs @@ -49,7 +49,9 @@ impl MachineAir for Poseidon2Chip { for r in 0..rounds { let mut row = [F::zero(); NUM_POSEIDON2_COLS]; let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); + cols.is_real = F::one(); + let is_receive = r == 0; let is_memory_read = r == 0; let is_initial_layer = r == 1; let is_external_layer = @@ -78,6 +80,10 @@ impl MachineAir for Poseidon2Chip { cols.right_input = poseidon2_event.right; cols.rounds[r] = F::one(); + if is_receive 
{ + cols.do_receive = F::one(); + } + if is_memory_read || is_memory_write { let memory_access_cols = cols.round_specific_cols.memory_access_mut(); @@ -97,6 +103,7 @@ impl MachineAir for Poseidon2Chip { .populate(&poseidon2_event.result_records[i]); } } + cols.do_memory = F::one(); } else { let computation_cols = cols.round_specific_cols.computation_mut(); @@ -131,6 +138,7 @@ impl MachineAir for Poseidon2Chip { let sbox_deg_3 = computation_cols.add_rc[j] * computation_cols.add_rc[j] * computation_cols.add_rc[j]; + computation_cols.sbox_deg_3[j] = sbox_deg_3; computation_cols.sbox_deg_7[j] = sbox_deg_3 * sbox_deg_3 * computation_cols.add_rc[j]; } @@ -163,12 +171,24 @@ impl MachineAir for Poseidon2Chip { } } + let num_real_rows = rows.len(); + // Pad the trace to a power of two. - pad_rows_fixed( - &mut rows, - || [F::zero(); NUM_POSEIDON2_COLS], - self.fixed_log2_rows, - ); + if self.pad { + pad_rows_fixed( + &mut rows, + || [F::zero(); NUM_POSEIDON2_COLS], + self.fixed_log2_rows, + ); + } + + let mut round_num = 0; + for row in rows[num_real_rows..].iter_mut() { + let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); + cols.rounds[round_num] = F::one(); + + round_num = (round_num + 1) % rounds; + } // Convert the trace to a row major matrix. RowMajorMatrix::new( diff --git a/recursion/core/src/runtime/mod.rs b/recursion/core/src/runtime/mod.rs index 1935c9cd69..bc3fd27913 100644 --- a/recursion/core/src/runtime/mod.rs +++ b/recursion/core/src/runtime/mod.rs @@ -22,6 +22,7 @@ pub use utils::*; use crate::air::{Block, RECURSION_PUBLIC_VALUES_COL_MAP, RECURSIVE_PROOF_NUM_PV_ELTS}; use crate::cpu::CpuEvent; +use crate::exp_reverse_bits::ExpReverseBitsLenEvent; use crate::fri_fold::FriFoldEvent; use crate::memory::MemoryRecord; use crate::poseidon2::Poseidon2Event; @@ -588,12 +589,11 @@ where .resize(RECURSIVE_PROOF_NUM_PV_ELTS, F::zero()); self.record.public_values[RECURSION_PUBLIC_VALUES_COL_MAP.exit_code] = F::one(); - let (a_val, b_val, c_val) = self.all_rr(&instruction); let trap_pc = self.pc.as_canonical_u32() as usize; let trace = self.program.traces[trap_pc].clone(); if let Some(mut trace) = trace { trace.resolve(); - eprintln!("TRAP encountered. Backtrace:\n{:?}", trace); + panic!("TRAP encountered. Backtrace:\n{:?}", trace); } else { for nearby_pc in (0..trap_pc).rev() { let trace = self.program.traces[nearby_pc].clone(); @@ -606,9 +606,8 @@ where exit(1); } } - eprintln!("TRAP encountered. No backtrace available"); + panic!("TRAP encountered. No backtrace available"); } - (a, b, c) = (a_val, b_val, c_val); } Opcode::HALT => { self.record @@ -822,6 +821,74 @@ where next_clk = timestamp; (a, b, c) = (a_val, b_val, c_val); } + Opcode::ExpReverseBitsLen => { + // Read the operands. + let (a_val, b_val, c_val) = self.all_rr(&instruction); + + // A pointer to the base of the exponentiation. + let base = a_val[0]; + + // A pointer to the first bit (LSB) of the exponent. + let input_ptr = b_val[0]; + + // The length parameter in bit-reverse-len. + let len = c_val[0]; + + let mut timestamp = self.clk; + + let mut accum = F::one(); + + // Read the value at the pointer `base`. + let mut x_record = self.mr(base, timestamp).0; + + // Iterate over the `len` least-significant bits of the exponent. + for m in 0..len.as_canonical_u32() { + let m = F::from_canonical_u32(m); + + // Pointer to the current bit. + let ptr = input_ptr + m; + + // Read the current bit. 
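+ // (Worked example with hypothetical values: for x = 3 and bits [1, 1, 0] taken LSB-first,
+ // i.e. exponent 3, the square-and-multiply below yields accum = 3, then 3*3*3 = 27, then
+ // 27*27 = 729 = 3^6; 6 is 3 with its three bits reversed (0b011 -> 0b110), so the loop
+ // computes x^bit_reverse(exponent, len).)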
+ let (current_bit_record, current_bit) = self.mr(ptr, timestamp); + let current_bit = current_bit.ext::().as_base_slice()[0]; + + // Extract the val in `x_record` + let current_x_val = x_record.value[0]; + + let prev_accum = accum; + accum = prev_accum + * prev_accum + * if current_bit == F::one() { + current_x_val + } else { + F::one() + }; + + // On the last iteration, write accum to the address pointed to in `base`. + if m == len - F::one() { + x_record = self.mw(base, Block::from(accum), timestamp); + }; + + // Add the event for this iteration to the `ExecutionRecord`. + self.record + .exp_reverse_bits_len_events + .push(ExpReverseBitsLenEvent { + clk: timestamp, + x: x_record, + current_bit: current_bit_record, + len: len - m, + prev_accum, + accum, + ptr, + base_ptr: base, + iteration_num: m, + }); + timestamp += F::one(); + } + + next_clk = timestamp; + (a, b, c) = (a_val, b_val, c_val); + } // For both the Commit and RegisterPublicValue opcodes, we record the public value Opcode::Commit | Opcode::RegisterPublicValue => { let (a_val, b_val, c_val) = self.all_rr(&instruction); diff --git a/recursion/core/src/runtime/opcode.rs b/recursion/core/src/runtime/opcode.rs index 8c373ca201..d6db8abc13 100644 --- a/recursion/core/src/runtime/opcode.rs +++ b/recursion/core/src/runtime/opcode.rs @@ -50,6 +50,7 @@ pub enum Opcode { RegisterPublicValue = 42, LessThanF = 43, CycleTracker = 44, + ExpReverseBitsLen = 45, } impl Opcode { diff --git a/recursion/core/src/runtime/record.rs b/recursion/core/src/runtime/record.rs index 80f6d7b321..c6c0217bf3 100644 --- a/recursion/core/src/runtime/record.rs +++ b/recursion/core/src/runtime/record.rs @@ -9,6 +9,7 @@ use std::collections::HashMap; use super::RecursionProgram; use crate::air::Block; use crate::cpu::CpuEvent; +use crate::exp_reverse_bits::ExpReverseBitsLenEvent; use crate::fri_fold::FriFoldEvent; use crate::poseidon2::Poseidon2Event; use crate::range_check::RangeCheckEvent; @@ -20,7 +21,7 @@ pub struct ExecutionRecord { pub poseidon2_events: Vec>, pub fri_fold_events: Vec>, pub range_check_events: BTreeMap, - + pub exp_reverse_bits_len_events: Vec>, // (address, value) pub first_memory_record: Vec<(F, Block)>, @@ -57,6 +58,10 @@ impl MachineRecord for ExecutionRecord { "range_check_events".to_string(), self.range_check_events.len(), ); + stats.insert( + "exp_reverse_bits_len_events".to_string(), + self.exp_reverse_bits_len_events.len(), + ); stats } diff --git a/recursion/core/src/stark/mod.rs b/recursion/core/src/stark/mod.rs index 540dcb613b..8f24874f5d 100644 --- a/recursion/core/src/stark/mod.rs +++ b/recursion/core/src/stark/mod.rs @@ -3,9 +3,9 @@ pub mod poseidon2; pub mod utils; use crate::{ - cpu::CpuChip, fri_fold::FriFoldChip, memory::MemoryGlobalChip, multi::MultiChip, - poseidon2::Poseidon2Chip, poseidon2_wide::Poseidon2WideChip, program::ProgramChip, - range_check::RangeCheckChip, + cpu::CpuChip, exp_reverse_bits::ExpReverseBitsLenChip, fri_fold::FriFoldChip, + memory::MemoryGlobalChip, multi::MultiChip, poseidon2::Poseidon2Chip, + poseidon2_wide::Poseidon2WideChip, program::ProgramChip, range_check::RangeCheckChip, }; use core::iter::once; use p3_field::{extension::BinomiallyExtendable, PrimeField32}; @@ -32,6 +32,7 @@ pub enum RecursionAir, const DEGREE: u FriFold(FriFoldChip), RangeCheck(RangeCheckChip), Multi(MultiChip), + ExpReverseBitsLen(ExpReverseBitsLenChip), } impl, const DEGREE: usize> RecursionAir { @@ -78,8 +79,15 @@ impl, const DEGREE: usize> RecursionAi }))) .chain(once(RecursionAir::FriFold(FriFoldChip:: { 
fixed_log2_rows: None, + pad: true, }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) + .chain(once(RecursionAir::ExpReverseBitsLen( + ExpReverseBitsLenChip:: { + fixed_log2_rows: None, + pad: true, + }, + ))) .collect() } @@ -96,6 +104,12 @@ impl, const DEGREE: usize> RecursionAi fixed_log2_rows: None, }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) + .chain(once(RecursionAir::ExpReverseBitsLen( + ExpReverseBitsLenChip:: { + fixed_log2_rows: None, + pad: true, + }, + ))) .collect() } @@ -109,9 +123,15 @@ impl, const DEGREE: usize> RecursionAi fixed_log2_rows: Some(19), }))) .chain(once(RecursionAir::Multi(MultiChip { - fixed_log2_rows: Some(20), + fixed_log2_rows: Some(19), }))) .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default()))) + .chain(once(RecursionAir::ExpReverseBitsLen( + ExpReverseBitsLenChip:: { + fixed_log2_rows: None, + pad: true, + }, + ))) .collect() } } diff --git a/recursion/gnark-cli/Cargo.toml b/recursion/gnark-cli/Cargo.toml new file mode 100644 index 0000000000..cd6fe2647c --- /dev/null +++ b/recursion/gnark-cli/Cargo.toml @@ -0,0 +1,10 @@ +[workspace] +[package] +name = "sp1-recursion-gnark-cli" +version = "0.1.0" +edition = "2021" + +[dependencies] +sp1-recursion-gnark-ffi = { path = "../gnark-ffi", features = ["native"] } +clap = { version = "4.3.8", features = ["derive"] } +bincode = "1.3.3" diff --git a/recursion/gnark-cli/src/main.rs b/recursion/gnark-cli/src/main.rs new file mode 100644 index 0000000000..61f5e0602a --- /dev/null +++ b/recursion/gnark-cli/src/main.rs @@ -0,0 +1,97 @@ +//! A simple CLI that wraps the gnark-ffi crate. This is called using Docker in gnark-ffi when the +//! native feature is disabled. + +use sp1_recursion_gnark_ffi::ffi::{ + build_plonk_bn254, prove_plonk_bn254, test_plonk_bn254, verify_plonk_bn254, +}; + +use clap::{Args, Parser, Subcommand}; +use std::{ + fs::File, + io::{read_to_string, Write}, +}; + +#[derive(Debug, Parser)] +struct Cli { + #[command(subcommand)] + command: Command, +} + +#[allow(clippy::enum_variant_names)] +#[derive(Debug, Subcommand)] +enum Command { + BuildPlonk(BuildArgs), + ProvePlonk(ProveArgs), + VerifyPlonk(VerifyArgs), + TestPlonk(TestArgs), +} + +#[derive(Debug, Args)] +struct BuildArgs { + data_dir: String, +} + +#[derive(Debug, Args)] +struct ProveArgs { + data_dir: String, + witness_path: String, + output_path: String, +} + +#[derive(Debug, Args)] +struct VerifyArgs { + data_dir: String, + proof_path: String, + vkey_hash: String, + committed_values_digest: String, + output_path: String, +} + +#[derive(Debug, Args)] +struct TestArgs { + witness_json: String, + constraints_json: String, +} + +fn run_build(args: BuildArgs) { + build_plonk_bn254(&args.data_dir); +} + +fn run_prove(args: ProveArgs) { + let proof = prove_plonk_bn254(&args.data_dir, &args.witness_path); + let mut file = File::create(&args.output_path).unwrap(); + bincode::serialize_into(&mut file, &proof).unwrap(); +} + +fn run_verify(args: VerifyArgs) { + // For proof, we read the string from file since it can be large. 
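+ // (When invoked through the docker wrapper in gnark-ffi, the proof arrives as a bind-mounted temp
+ // file; the container is called roughly as `verify-plonk /circuit /proof <vkey_hash> <committed_values_digest> /output`,
+ // matching the call in ffi/docker.rs below.)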
+ let file = File::open(&args.proof_path).unwrap(); + let proof = read_to_string(file).unwrap(); + let result = verify_plonk_bn254( + &args.data_dir, + proof.trim(), + &args.vkey_hash, + &args.committed_values_digest, + ); + let output = match result { + Ok(_) => "OK".to_string(), + Err(e) => e, + }; + let mut file = File::create(&args.output_path).unwrap(); + file.write_all(output.as_bytes()).unwrap(); +} + +fn run_test(args: TestArgs) { + test_plonk_bn254(&args.witness_json, &args.constraints_json); +} + +fn main() { + let cli = Cli::parse(); + + match cli.command { + Command::BuildPlonk(args) => run_build(args), + Command::ProvePlonk(args) => run_prove(args), + Command::VerifyPlonk(args) => run_verify(args), + Command::TestPlonk(args) => run_test(args), + } +} diff --git a/recursion/gnark-ffi/Cargo.toml b/recursion/gnark-ffi/Cargo.toml index d10ed2fd31..7ac3cea14f 100644 --- a/recursion/gnark-ffi/Cargo.toml +++ b/recursion/gnark-ffi/Cargo.toml @@ -5,8 +5,10 @@ edition = "2021" [dependencies] p3-field = { workspace = true } +p3-symmetric = { workspace = true } p3-baby-bear = { workspace = true } sp1-recursion-compiler = { path = "../compiler" } +sp1-core = { path = "../../core" } serde = "1.0.201" serde_json = "1.0.117" tempfile = "3.10.1" @@ -14,6 +16,10 @@ rand = "0.8" log = "0.4.21" num-bigint = "0.4.5" cfg-if = "1.0" +bincode = "1.3.3" +anyhow = "1.0.86" +sha2 = "0.10.8" +hex = "0.4.3" [build-dependencies] bindgen = "0.69.4" @@ -21,4 +27,4 @@ cc = "1.0" cfg-if = "1.0" [features] -plonk = [] +native = [] diff --git a/recursion/gnark-ffi/assets/ISP1Verifier.txt b/recursion/gnark-ffi/assets/ISP1Verifier.txt index 4298f4b8cf..1ea8727937 100644 --- a/recursion/gnark-ffi/assets/ISP1Verifier.txt +++ b/recursion/gnark-ffi/assets/ISP1Verifier.txt @@ -1,20 +1,23 @@ // SPDX-License-Identifier: MIT -pragma solidity ^0.8.25; +pragma solidity ^0.8.19; /// @title SP1 Verifier Interface /// @author Succinct Labs /// @notice This contract is the interface for the SP1 Verifier. interface ISP1Verifier { - /// @notice Returns the version of the SP1 Verifier. + /// @notice Returns the version of SP1 this verifier corresponds to. function VERSION() external pure returns (string memory); + /// @notice Returns the hash of the verification key. + function VKEY_HASH() external pure returns (bytes32); + /// @notice Verifies a proof with given public values and vkey. /// @param vkey The verification key for the RISC-V program. /// @param publicValues The public values encoded as bytes. /// @param proofBytes The proof of the program execution the SP1 zkVM encoded as bytes. 
function verifyProof( bytes32 vkey, - bytes memory publicValues, - bytes memory proofBytes + bytes calldata publicValues, + bytes calldata proofBytes ) external view; } diff --git a/recursion/gnark-ffi/assets/SP1MockVerifier.txt b/recursion/gnark-ffi/assets/SP1MockVerifier.txt index bf938718b6..0ca6f4a1fb 100644 --- a/recursion/gnark-ffi/assets/SP1MockVerifier.txt +++ b/recursion/gnark-ffi/assets/SP1MockVerifier.txt @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -pragma solidity ^0.8.25; +pragma solidity ^0.8.19; import {ISP1Verifier} from "./ISP1Verifier.sol"; diff --git a/recursion/gnark-ffi/assets/SP1Verifier.txt b/recursion/gnark-ffi/assets/SP1Verifier.txt index f7389163e8..2cc044bfb5 100644 --- a/recursion/gnark-ffi/assets/SP1Verifier.txt +++ b/recursion/gnark-ffi/assets/SP1Verifier.txt @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -pragma solidity ^0.8.25; +pragma solidity ^0.8.19; import {ISP1Verifier} from "./ISP1Verifier.sol"; import {PlonkVerifier} from "./PlonkVerifier.sol"; @@ -8,14 +8,20 @@ import {PlonkVerifier} from "./PlonkVerifier.sol"; /// @author Succinct Labs /// @notice This contracts implements a solidity verifier for SP1. contract SP1Verifier is PlonkVerifier { + error WrongVersionProof(); + function VERSION() external pure returns (string memory) { - return "TODO"; + return "{SP1_CIRCUIT_VERSION}"; + } + + function VKEY_HASH() public pure returns (bytes32) { + return {VKEY_HASH}; } /// @notice Hashes the public values to a field elements inside Bn254. /// @param publicValues The public values. function hashPublicValues( - bytes memory publicValues + bytes calldata publicValues ) public pure returns (bytes32) { return sha256(publicValues) & bytes32(uint256((1 << 253) - 1)); } @@ -26,13 +32,20 @@ contract SP1Verifier is PlonkVerifier { /// @param proofBytes The proof of the program execution the SP1 zkVM encoded as bytes. function verifyProof( bytes32 vkey, - bytes memory publicValues, - bytes memory proofBytes + bytes calldata publicValues, + bytes calldata proofBytes ) public view { + // To ensure the proof corresponds to this verifier, we check that the first 4 bytes of + // proofBytes match the first 4 bytes of VKEY_HASH. + bytes4 proofBytesPrefix = bytes4(proofBytes[:4]); + if (proofBytesPrefix != bytes4(VKEY_HASH())) { + revert WrongVersionProof(); + } + bytes32 publicValuesDigest = hashPublicValues(publicValues); uint256[] memory inputs = new uint256[](2); inputs[0] = uint256(vkey); inputs[1] = uint256(publicValuesDigest); - this.Verify(proofBytes, inputs); + this.Verify(proofBytes[4:], inputs); } } \ No newline at end of file diff --git a/recursion/gnark-ffi/build.rs b/recursion/gnark-ffi/build.rs index 650744599c..0aa355524a 100644 --- a/recursion/gnark-ffi/build.rs +++ b/recursion/gnark-ffi/build.rs @@ -11,7 +11,7 @@ use bindgen::CargoCallbacks; /// Build the go library, generate Rust bindings for the exposed functions, and link the library. fn main() { cfg_if! 
{ - if #[cfg(feature = "plonk")] { + if #[cfg(feature = "native")] { println!("cargo:rerun-if-changed=go"); // Define the output directory let out_dir = env::var("OUT_DIR").unwrap(); diff --git a/recursion/gnark-ffi/go/main.go b/recursion/gnark-ffi/go/main.go index 89bba4a7e8..ed782400f2 100644 --- a/recursion/gnark-ffi/go/main.go +++ b/recursion/gnark-ffi/go/main.go @@ -17,11 +17,15 @@ import ( "sync" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test/unsafekzg" "github.com/succinctlabs/sp1-recursion-gnark/sp1" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/poseidon2" ) func main() {} @@ -141,3 +145,73 @@ func TestMain() error { return nil } + +//export TestPoseidonBabyBear2 +func TestPoseidonBabyBear2() *C.char { + input := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + } + + expectedOutput := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("348670919"), + babybear.NewF("1568590631"), + babybear.NewF("1535107508"), + babybear.NewF("186917780"), + babybear.NewF("587749971"), + babybear.NewF("1827585060"), + babybear.NewF("1218809104"), + babybear.NewF("691692291"), + babybear.NewF("1480664293"), + babybear.NewF("1491566329"), + babybear.NewF("366224457"), + babybear.NewF("490018300"), + babybear.NewF("732772134"), + babybear.NewF("560796067"), + babybear.NewF("484676252"), + babybear.NewF("405025962"), + } + + circuit := sp1.TestPoseidon2BabyBearCircuit{Input: input, ExpectedOutput: expectedOutput} + assignment := sp1.TestPoseidon2BabyBearCircuit{Input: input, ExpectedOutput: expectedOutput} + + builder := r1cs.NewBuilder + r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), builder, &circuit) + if err != nil { + return C.CString(err.Error()) + } + + var pk groth16.ProvingKey + pk, err = groth16.DummySetup(r1cs) + if err != nil { + return C.CString(err.Error()) + } + + // Generate witness. + witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + if err != nil { + return C.CString(err.Error()) + } + + // Generate the proof. + _, err = groth16.Prove(r1cs, pk, witness) + if err != nil { + return C.CString(err.Error()) + } + + return nil +} diff --git a/recursion/gnark-ffi/go/sp1/poseidon2/constants.go b/recursion/gnark-ffi/go/sp1/poseidon2/constants.go index fb350f180e..edb5a5e4a7 100644 --- a/recursion/gnark-ffi/go/sp1/poseidon2/constants.go +++ b/recursion/gnark-ffi/go/sp1/poseidon2/constants.go @@ -3,11 +3,22 @@ package poseidon2 import ( "github.com/consensys/gnark/frontend" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" ) +// Poseidon2 round constants for a state consisting of three BN254 field elements. var RC3 [NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS][WIDTH]frontend.Variable +// Poseidon2 round constaints for a state consisting of 16 BabyBear field elements. 
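+// (Assumption: these 30 x 16 constants mirror the BabyBear Poseidon2 round constants used on the
+// Rust side, sp1_primitives::RC_16_30_U32; the correspondence is not checked in this file.)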
+ +var RC16 [30][BABYBEAR_WIDTH]babybear.Variable + func init() { + init_rc3() + init_rc16() +} + +func init_rc3() { round := 0 RC3[round] = [WIDTH]frontend.Variable{ @@ -457,3 +468,580 @@ func init() { frontend.Variable("0x0fc1bbceba0590f5abbdffa6d3b35e3297c021a3a409926d0e2d54dc1c84fda6"), } } + +func init_rc16() { + round := 0 + + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2110014213"), + babybear.NewF("3964964605"), + babybear.NewF("2190662774"), + babybear.NewF("2732996483"), + babybear.NewF("640767983"), + babybear.NewF("3403899136"), + babybear.NewF("1716033721"), + babybear.NewF("1606702601"), + babybear.NewF("3759873288"), + babybear.NewF("1466015491"), + babybear.NewF("1498308946"), + babybear.NewF("2844375094"), + babybear.NewF("3042463841"), + babybear.NewF("1969905919"), + babybear.NewF("4109944726"), + babybear.NewF("3925048366"), + } + + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3706859504"), + babybear.NewF("759122502"), + babybear.NewF("3167665446"), + babybear.NewF("1131812921"), + babybear.NewF("1080754908"), + babybear.NewF("4080114493"), + babybear.NewF("893583089"), + babybear.NewF("2019677373"), + babybear.NewF("3128604556"), + babybear.NewF("580640471"), + babybear.NewF("3277620260"), + babybear.NewF("842931656"), + babybear.NewF("548879852"), + babybear.NewF("3608554714"), + babybear.NewF("3575647916"), + babybear.NewF("81826002"), + } + + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("4289086263"), + babybear.NewF("1563933798"), + babybear.NewF("1440025885"), + babybear.NewF("184445025"), + babybear.NewF("2598651360"), + babybear.NewF("1396647410"), + babybear.NewF("1575877922"), + babybear.NewF("3303853401"), + babybear.NewF("137125468"), + babybear.NewF("765010148"), + babybear.NewF("633675867"), + babybear.NewF("2037803363"), + babybear.NewF("2573389828"), + babybear.NewF("1895729703"), + babybear.NewF("541515871"), + babybear.NewF("1783382863"), + } + + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2641856484"), + babybear.NewF("3035743342"), + babybear.NewF("3672796326"), + babybear.NewF("245668751"), + babybear.NewF("2025460432"), + babybear.NewF("201609705"), + babybear.NewF("286217151"), + babybear.NewF("4093475563"), + babybear.NewF("2519572182"), + babybear.NewF("3080699870"), + babybear.NewF("2762001832"), + babybear.NewF("1244250808"), + babybear.NewF("606038199"), + babybear.NewF("3182740831"), + babybear.NewF("73007766"), + babybear.NewF("2572204153"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1196780786"), + babybear.NewF("3447394443"), + babybear.NewF("747167305"), + babybear.NewF("2968073607"), + babybear.NewF("1053214930"), + babybear.NewF("1074411832"), + babybear.NewF("4016794508"), + babybear.NewF("1570312929"), + babybear.NewF("113576933"), + babybear.NewF("4042581186"), + babybear.NewF("3634515733"), + babybear.NewF("1032701597"), + babybear.NewF("2364839308"), + babybear.NewF("3840286918"), + babybear.NewF("888378655"), + babybear.NewF("2520191583"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("36046858"), + babybear.NewF("2927525953"), + babybear.NewF("3912129105"), + babybear.NewF("4004832531"), + babybear.NewF("193772436"), + babybear.NewF("1590247392"), + babybear.NewF("4125818172"), + babybear.NewF("2516251696"), + babybear.NewF("4050945750"), + babybear.NewF("269498914"), + babybear.NewF("1973292656"), + babybear.NewF("891403491"), 
+ babybear.NewF("1845429189"), + babybear.NewF("2611996363"), + babybear.NewF("2310542653"), + babybear.NewF("4071195740"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3505307391"), + babybear.NewF("786445290"), + babybear.NewF("3815313971"), + babybear.NewF("1111591756"), + babybear.NewF("4233279834"), + babybear.NewF("2775453034"), + babybear.NewF("1991257625"), + babybear.NewF("2940505809"), + babybear.NewF("2751316206"), + babybear.NewF("1028870679"), + babybear.NewF("1282466273"), + babybear.NewF("1059053371"), + babybear.NewF("834521354"), + babybear.NewF("138721483"), + babybear.NewF("3100410803"), + babybear.NewF("3843128331"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3878220780"), + babybear.NewF("4058162439"), + babybear.NewF("1478942487"), + babybear.NewF("799012923"), + babybear.NewF("496734827"), + babybear.NewF("3521261236"), + babybear.NewF("755421082"), + babybear.NewF("1361409515"), + babybear.NewF("392099473"), + babybear.NewF("3178453393"), + babybear.NewF("4068463721"), + babybear.NewF("7935614"), + babybear.NewF("4140885645"), + babybear.NewF("2150748066"), + babybear.NewF("1685210312"), + babybear.NewF("3852983224"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2896943075"), + babybear.NewF("3087590927"), + babybear.NewF("992175959"), + babybear.NewF("970216228"), + babybear.NewF("3473630090"), + babybear.NewF("3899670400"), + babybear.NewF("3603388822"), + babybear.NewF("2633488197"), + babybear.NewF("2479406964"), + babybear.NewF("2420952999"), + babybear.NewF("1852516800"), + babybear.NewF("4253075697"), + babybear.NewF("979699862"), + babybear.NewF("1163403191"), + babybear.NewF("1608599874"), + babybear.NewF("3056104448"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3779109343"), + babybear.NewF("536205958"), + babybear.NewF("4183458361"), + babybear.NewF("1649720295"), + babybear.NewF("1444912244"), + babybear.NewF("3122230878"), + babybear.NewF("384301396"), + babybear.NewF("4228198516"), + babybear.NewF("1662916865"), + babybear.NewF("4082161114"), + babybear.NewF("2121897314"), + babybear.NewF("1706239958"), + babybear.NewF("4166959388"), + babybear.NewF("1626054781"), + babybear.NewF("3005858978"), + babybear.NewF("1431907253"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1418914503"), + babybear.NewF("1365856753"), + babybear.NewF("3942715745"), + babybear.NewF("1429155552"), + babybear.NewF("3545642795"), + babybear.NewF("3772474257"), + babybear.NewF("1621094396"), + babybear.NewF("2154399145"), + babybear.NewF("826697382"), + babybear.NewF("1700781391"), + babybear.NewF("3539164324"), + babybear.NewF("652815039"), + babybear.NewF("442484755"), + babybear.NewF("2055299391"), + babybear.NewF("1064289978"), + babybear.NewF("1152335780"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3417648695"), + babybear.NewF("186040114"), + babybear.NewF("3475580573"), + babybear.NewF("2113941250"), + babybear.NewF("1779573826"), + babybear.NewF("1573808590"), + babybear.NewF("3235694804"), + babybear.NewF("2922195281"), + babybear.NewF("1119462702"), + babybear.NewF("3688305521"), + babybear.NewF("1849567013"), + babybear.NewF("667446787"), + babybear.NewF("753897224"), + babybear.NewF("1896396780"), + babybear.NewF("3143026334"), + babybear.NewF("3829603876"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ 
+ babybear.NewF("859661334"), + babybear.NewF("3898844357"), + babybear.NewF("180258337"), + babybear.NewF("2321867017"), + babybear.NewF("3599002504"), + babybear.NewF("2886782421"), + babybear.NewF("3038299378"), + babybear.NewF("1035366250"), + babybear.NewF("2038912197"), + babybear.NewF("2920174523"), + babybear.NewF("1277696101"), + babybear.NewF("2785700290"), + babybear.NewF("3806504335"), + babybear.NewF("3518858933"), + babybear.NewF("654843672"), + babybear.NewF("2127120275"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1548195514"), + babybear.NewF("2378056027"), + babybear.NewF("390914568"), + babybear.NewF("1472049779"), + babybear.NewF("1552596765"), + babybear.NewF("1905886441"), + babybear.NewF("1611959354"), + babybear.NewF("3653263304"), + babybear.NewF("3423946386"), + babybear.NewF("340857935"), + babybear.NewF("2208879480"), + babybear.NewF("139364268"), + babybear.NewF("3447281773"), + babybear.NewF("3777813707"), + babybear.NewF("55640413"), + babybear.NewF("4101901741"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("104929687"), + babybear.NewF("1459980974"), + babybear.NewF("1831234737"), + babybear.NewF("457139004"), + babybear.NewF("2581487628"), + babybear.NewF("2112044563"), + babybear.NewF("3567013861"), + babybear.NewF("2792004347"), + babybear.NewF("576325418"), + babybear.NewF("41126132"), + babybear.NewF("2713562324"), + babybear.NewF("151213722"), + babybear.NewF("2891185935"), + babybear.NewF("546846420"), + babybear.NewF("2939794919"), + babybear.NewF("2543469905"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2191909784"), + babybear.NewF("3315138460"), + babybear.NewF("530414574"), + babybear.NewF("1242280418"), + babybear.NewF("1211740715"), + babybear.NewF("3993672165"), + babybear.NewF("2505083323"), + babybear.NewF("3845798801"), + babybear.NewF("538768466"), + babybear.NewF("2063567560"), + babybear.NewF("3366148274"), + babybear.NewF("1449831887"), + babybear.NewF("2408012466"), + babybear.NewF("294726285"), + babybear.NewF("3943435493"), + babybear.NewF("924016661"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3633138367"), + babybear.NewF("3222789372"), + babybear.NewF("809116305"), + babybear.NewF("30100013"), + babybear.NewF("2655172876"), + babybear.NewF("2564247117"), + babybear.NewF("2478649732"), + babybear.NewF("4113689151"), + babybear.NewF("4120146082"), + babybear.NewF("2512308515"), + babybear.NewF("650406041"), + babybear.NewF("4240012393"), + babybear.NewF("2683508708"), + babybear.NewF("951073977"), + babybear.NewF("3460081988"), + babybear.NewF("339124269"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("130182653"), + babybear.NewF("2755946749"), + babybear.NewF("542600513"), + babybear.NewF("2816103022"), + babybear.NewF("1931786340"), + babybear.NewF("2044470840"), + babybear.NewF("1709908013"), + babybear.NewF("2938369043"), + babybear.NewF("3640399693"), + babybear.NewF("1374470239"), + babybear.NewF("2191149676"), + babybear.NewF("2637495682"), + babybear.NewF("4236394040"), + babybear.NewF("2289358846"), + babybear.NewF("3833368530"), + babybear.NewF("974546524"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3306659113"), + babybear.NewF("2234814261"), + babybear.NewF("1188782305"), + babybear.NewF("223782844"), + babybear.NewF("2248980567"), + babybear.NewF("2309786141"), + 
babybear.NewF("2023401627"), + babybear.NewF("3278877413"), + babybear.NewF("2022138149"), + babybear.NewF("575851471"), + babybear.NewF("1612560780"), + babybear.NewF("3926656936"), + babybear.NewF("3318548977"), + babybear.NewF("2591863678"), + babybear.NewF("188109355"), + babybear.NewF("4217723909"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1564209905"), + babybear.NewF("2154197895"), + babybear.NewF("2459687029"), + babybear.NewF("2870634489"), + babybear.NewF("1375012945"), + babybear.NewF("1529454825"), + babybear.NewF("306140690"), + babybear.NewF("2855578299"), + babybear.NewF("1246997295"), + babybear.NewF("3024298763"), + babybear.NewF("1915270363"), + babybear.NewF("1218245412"), + babybear.NewF("2479314020"), + babybear.NewF("2989827755"), + babybear.NewF("814378556"), + babybear.NewF("4039775921"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1165280628"), + babybear.NewF("1203983801"), + babybear.NewF("3814740033"), + babybear.NewF("1919627044"), + babybear.NewF("600240215"), + babybear.NewF("773269071"), + babybear.NewF("486685186"), + babybear.NewF("4254048810"), + babybear.NewF("1415023565"), + babybear.NewF("502840102"), + babybear.NewF("4225648358"), + babybear.NewF("510217063"), + babybear.NewF("166444818"), + babybear.NewF("1430745893"), + babybear.NewF("1376516190"), + babybear.NewF("1775891321"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1170945922"), + babybear.NewF("1105391877"), + babybear.NewF("261536467"), + babybear.NewF("1401687994"), + babybear.NewF("1022529847"), + babybear.NewF("2476446456"), + babybear.NewF("2603844878"), + babybear.NewF("3706336043"), + babybear.NewF("3463053714"), + babybear.NewF("1509644517"), + babybear.NewF("588552318"), + babybear.NewF("65252581"), + babybear.NewF("3696502656"), + babybear.NewF("2183330763"), + babybear.NewF("3664021233"), + babybear.NewF("1643809916"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2922875898"), + babybear.NewF("3740690643"), + babybear.NewF("3932461140"), + babybear.NewF("161156271"), + babybear.NewF("2619943483"), + babybear.NewF("4077039509"), + babybear.NewF("2921201703"), + babybear.NewF("2085619718"), + babybear.NewF("2065264646"), + babybear.NewF("2615693812"), + babybear.NewF("3116555433"), + babybear.NewF("246100007"), + babybear.NewF("4281387154"), + babybear.NewF("4046141001"), + babybear.NewF("4027749321"), + babybear.NewF("111611860"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2066954820"), + babybear.NewF("2502099969"), + babybear.NewF("2915053115"), + babybear.NewF("2362518586"), + babybear.NewF("366091708"), + babybear.NewF("2083204932"), + babybear.NewF("4138385632"), + babybear.NewF("3195157567"), + babybear.NewF("1318086382"), + babybear.NewF("521723799"), + babybear.NewF("702443405"), + babybear.NewF("2507670985"), + babybear.NewF("1760347557"), + babybear.NewF("2631999893"), + babybear.NewF("1672737554"), + babybear.NewF("1060867760"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2359801781"), + babybear.NewF("2800231467"), + babybear.NewF("3010357035"), + babybear.NewF("1035997899"), + babybear.NewF("1210110952"), + babybear.NewF("1018506770"), + babybear.NewF("2799468177"), + babybear.NewF("1479380761"), + babybear.NewF("1536021911"), + babybear.NewF("358993854"), + babybear.NewF("579904113"), + babybear.NewF("3432144800"), + 
babybear.NewF("3625515809"), + babybear.NewF("199241497"), + babybear.NewF("4058304109"), + babybear.NewF("2590164234"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1688530738"), + babybear.NewF("1580733335"), + babybear.NewF("2443981517"), + babybear.NewF("2206270565"), + babybear.NewF("2780074229"), + babybear.NewF("2628739677"), + babybear.NewF("2940123659"), + babybear.NewF("4145206827"), + babybear.NewF("3572278009"), + babybear.NewF("2779607509"), + babybear.NewF("1098718697"), + babybear.NewF("1424913749"), + babybear.NewF("2224415875"), + babybear.NewF("1108922178"), + babybear.NewF("3646272562"), + babybear.NewF("3935186184"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("820046587"), + babybear.NewF("1393386250"), + babybear.NewF("2665818575"), + babybear.NewF("2231782019"), + babybear.NewF("672377010"), + babybear.NewF("1920315467"), + babybear.NewF("1913164407"), + babybear.NewF("2029526876"), + babybear.NewF("2629271820"), + babybear.NewF("384320012"), + babybear.NewF("4112320585"), + babybear.NewF("3131824773"), + babybear.NewF("2347818197"), + babybear.NewF("2220997386"), + babybear.NewF("1772368609"), + babybear.NewF("2579960095"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3544930873"), + babybear.NewF("225847443"), + babybear.NewF("3070082278"), + babybear.NewF("95643305"), + babybear.NewF("3438572042"), + babybear.NewF("3312856509"), + babybear.NewF("615850007"), + babybear.NewF("1863868773"), + babybear.NewF("803582265"), + babybear.NewF("3461976859"), + babybear.NewF("2903025799"), + babybear.NewF("1482092434"), + babybear.NewF("3902972499"), + babybear.NewF("3872341868"), + babybear.NewF("1530411808"), + babybear.NewF("2214923584"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3118792481"), + babybear.NewF("2241076515"), + babybear.NewF("3983669831"), + babybear.NewF("3180915147"), + babybear.NewF("3838626501"), + babybear.NewF("1921630011"), + babybear.NewF("3415351771"), + babybear.NewF("2249953859"), + babybear.NewF("3755081630"), + babybear.NewF("486327260"), + babybear.NewF("1227575720"), + babybear.NewF("3643869379"), + babybear.NewF("2982026073"), + babybear.NewF("2466043731"), + babybear.NewF("1982634375"), + babybear.NewF("3769609014"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2195455495"), + babybear.NewF("2596863283"), + babybear.NewF("4244994973"), + babybear.NewF("1983609348"), + babybear.NewF("4019674395"), + babybear.NewF("3469982031"), + babybear.NewF("1458697570"), + babybear.NewF("1593516217"), + babybear.NewF("1963896497"), + babybear.NewF("3115309118"), + babybear.NewF("1659132465"), + babybear.NewF("2536770756"), + babybear.NewF("3059294171"), + babybear.NewF("2618031334"), + babybear.NewF("2040903247"), + babybear.NewF("3799795076"), + } +} diff --git a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go new file mode 100644 index 0000000000..9f83956234 --- /dev/null +++ b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go @@ -0,0 +1,157 @@ +package poseidon2 + +import ( + "github.com/consensys/gnark/frontend" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" +) + +const BABYBEAR_WIDTH = 16 +const BABYBEAR_NUM_EXTERNAL_ROUNDS = 8 +const BABYBEAR_NUM_INTERNAL_ROUNDS = 13 +const BABYBEAR_DEGREE = 7 + +type Poseidon2BabyBearChip struct { + api frontend.API + 
fieldApi *babybear.Chip +} + +func NewBabyBearChip(api frontend.API) *Poseidon2BabyBearChip { + return &Poseidon2BabyBearChip{ + api: api, + fieldApi: babybear.NewChip(api), + } +} + +func (p *Poseidon2BabyBearChip) PermuteMut(state *[BABYBEAR_WIDTH]babybear.Variable) { + // The initial linear layer. + p.externalLinearLayer(state) + + // The first half of the external rounds. + rounds := BABYBEAR_NUM_EXTERNAL_ROUNDS + BABYBEAR_NUM_INTERNAL_ROUNDS + roundsFBeggining := BABYBEAR_NUM_EXTERNAL_ROUNDS / 2 + for r := 0; r < roundsFBeggining; r++ { + p.addRc(state, RC16[r]) + p.sbox(state) + p.externalLinearLayer(state) + } + + // The internal rounds. + p_end := roundsFBeggining + BABYBEAR_NUM_INTERNAL_ROUNDS + for r := roundsFBeggining; r < p_end; r++ { + state[0] = p.fieldApi.AddF(state[0], RC16[r][0]) + state[0] = p.sboxP(state[0]) + p.diffusionPermuteMut(state) + } + + // The second half of the external rounds. + for r := p_end; r < rounds; r++ { + p.addRc(state, RC16[r]) + p.sbox(state) + p.externalLinearLayer(state) + } +} + +func (p *Poseidon2BabyBearChip) addRc(state *[BABYBEAR_WIDTH]babybear.Variable, rc [BABYBEAR_WIDTH]babybear.Variable) { + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.fieldApi.AddF(state[i], rc[i]) + } +} + +func (p *Poseidon2BabyBearChip) sboxP(input babybear.Variable) babybear.Variable { + zero := babybear.NewF("0") + inputCpy := p.fieldApi.AddF(input, zero) + inputCpy = p.fieldApi.ReduceSlow(inputCpy) + inputValue := inputCpy.Value + i2 := p.api.Mul(inputValue, inputValue) + i4 := p.api.Mul(i2, i2) + i6 := p.api.Mul(i4, i2) + i7 := p.api.Mul(i6, inputValue) + i7bb := p.fieldApi.ReduceSlow(babybear.Variable{ + Value: i7, + NbBits: 31 * 7, + }) + return i7bb +} + +func (p *Poseidon2BabyBearChip) sbox(state *[BABYBEAR_WIDTH]babybear.Variable) { + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.sboxP(state[i]) + } +} + +func (p *Poseidon2BabyBearChip) mdsLightPermutation4x4(state []babybear.Variable) { + t01 := p.fieldApi.AddF(state[0], state[1]) + t23 := p.fieldApi.AddF(state[2], state[3]) + t0123 := p.fieldApi.AddF(t01, t23) + t01123 := p.fieldApi.AddF(t0123, state[1]) + t01233 := p.fieldApi.AddF(t0123, state[3]) + state[3] = p.fieldApi.AddF(t01233, p.fieldApi.MulFConst(state[0], 2)) + state[1] = p.fieldApi.AddF(t01123, p.fieldApi.MulFConst(state[2], 2)) + state[0] = p.fieldApi.AddF(t01123, t01) + state[2] = p.fieldApi.AddF(t01233, t23) +} + +func (p *Poseidon2BabyBearChip) externalLinearLayer(state *[BABYBEAR_WIDTH]babybear.Variable) { + for i := 0; i < BABYBEAR_WIDTH; i += 4 { + p.mdsLightPermutation4x4(state[i : i+4]) + } + + sums := [4]babybear.Variable{ + state[0], + state[1], + state[2], + state[3], + } + for i := 4; i < BABYBEAR_WIDTH; i += 4 { + sums[0] = p.fieldApi.AddF(sums[0], state[i]) + sums[1] = p.fieldApi.AddF(sums[1], state[i+1]) + sums[2] = p.fieldApi.AddF(sums[2], state[i+2]) + sums[3] = p.fieldApi.AddF(sums[3], state[i+3]) + } + + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.fieldApi.AddF(state[i], sums[i%4]) + } +} + +func (p *Poseidon2BabyBearChip) diffusionPermuteMut(state *[BABYBEAR_WIDTH]babybear.Variable) { + matInternalDiagM1 := [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2013265919"), + babybear.NewF("1"), + babybear.NewF("2"), + babybear.NewF("4"), + babybear.NewF("8"), + babybear.NewF("16"), + babybear.NewF("32"), + babybear.NewF("64"), + babybear.NewF("128"), + babybear.NewF("256"), + babybear.NewF("512"), + babybear.NewF("1024"), + babybear.NewF("2048"), + babybear.NewF("4096"), + babybear.NewF("8192"), + 
babybear.NewF("32768"), + } + montyInverse := babybear.NewF("943718400") + p.matmulInternal(state, &matInternalDiagM1) + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.fieldApi.MulF(state[i], montyInverse) + } + +} + +func (p *Poseidon2BabyBearChip) matmulInternal( + state *[BABYBEAR_WIDTH]babybear.Variable, + matInternalDiagM1 *[BABYBEAR_WIDTH]babybear.Variable, +) { + sum := babybear.NewF("0") + for i := 0; i < BABYBEAR_WIDTH; i++ { + sum = p.fieldApi.AddF(sum, state[i]) + } + + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) + state[i] = p.fieldApi.AddF(state[i], sum) + } +} diff --git a/recursion/gnark-ffi/go/sp1/prove.go b/recursion/gnark-ffi/go/sp1/prove.go index c921d11435..4a71492390 100644 --- a/recursion/gnark-ffi/go/sp1/prove.go +++ b/recursion/gnark-ffi/go/sp1/prove.go @@ -1,6 +1,7 @@ package sp1 import ( + "bufio" "encoding/json" "os" @@ -30,7 +31,8 @@ func Prove(dataDir string, witnessPath string) Proof { panic(err) } pk := plonk.NewProvingKey(ecc.BN254) - pk.UnsafeReadFrom(pkFile) + bufReader := bufio.NewReaderSize(pkFile, 1024*1024) + pk.UnsafeReadFrom(bufReader) // Read the verifier key. vkFile, err := os.Open(dataDir + "/" + VK_PATH) diff --git a/recursion/gnark-ffi/go/sp1/sp1.go b/recursion/gnark-ffi/go/sp1/sp1.go index f3f3b24a51..ccde520953 100644 --- a/recursion/gnark-ffi/go/sp1/sp1.go +++ b/recursion/gnark-ffi/go/sp1/sp1.go @@ -68,6 +68,7 @@ func (circuit *Circuit) Define(api frontend.API) error { } hashAPI := poseidon2.NewChip(api) + hashBabyBearAPI := poseidon2.NewBabyBearChip(api) fieldAPI := babybear.NewChip(api) vars := make(map[string]frontend.Variable) felts := make(map[string]babybear.Variable) @@ -132,6 +133,15 @@ func (circuit *Circuit) Define(api frontend.API) error { vars[cs.Args[0][0]] = state[0] vars[cs.Args[1][0]] = state[1] vars[cs.Args[2][0]] = state[2] + case "PermuteBabyBear": + var state [16]babybear.Variable + for i := 0; i < 16; i++ { + state[i] = felts[cs.Args[i][0]] + } + hashBabyBearAPI.PermuteMut(&state) + for i := 0; i < 16; i++ { + felts[cs.Args[i][0]] = state[i] + } case "SelectV": vars[cs.Args[0][0]] = api.Select(vars[cs.Args[1][0]], vars[cs.Args[2][0]], vars[cs.Args[3][0]]) case "SelectF": diff --git a/recursion/gnark-ffi/go/sp1/test.go b/recursion/gnark-ffi/go/sp1/test.go new file mode 100644 index 0000000000..8d2aa8f0ae --- /dev/null +++ b/recursion/gnark-ffi/go/sp1/test.go @@ -0,0 +1,31 @@ +package sp1 + +import ( + "github.com/consensys/gnark/frontend" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/poseidon2" +) + +type TestPoseidon2BabyBearCircuit struct { + Input [poseidon2.BABYBEAR_WIDTH]babybear.Variable `gnark:",public"` + ExpectedOutput [poseidon2.BABYBEAR_WIDTH]babybear.Variable `gnark:",public"` +} + +func (circuit *TestPoseidon2BabyBearCircuit) Define(api frontend.API) error { + poseidon2BabyBearChip := poseidon2.NewBabyBearChip(api) + fieldApi := babybear.NewChip(api) + + zero := babybear.NewF("0") + input := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{} + for i := 0; i < poseidon2.BABYBEAR_WIDTH; i++ { + input[i] = fieldApi.AddF(circuit.Input[i], zero) + } + + poseidon2BabyBearChip.PermuteMut(&input) + + for i := 0; i < poseidon2.BABYBEAR_WIDTH; i++ { + fieldApi.AssertIsEqualF(circuit.ExpectedOutput[i], input[i]) + } + + return nil +} diff --git a/recursion/gnark-ffi/src/ffi.rs b/recursion/gnark-ffi/src/ffi.rs deleted file mode 100644 index d7ecf9d612..0000000000 --- a/recursion/gnark-ffi/src/ffi.rs 
+++ /dev/null @@ -1,142 +0,0 @@ -#![allow(unused)] - -//! FFI bindings for the Go code. The functions exported in this module are safe to call from Rust. -//! All C strings and other C memory should be freed in Rust, including C Strings returned by Go. -//! Although we cast to *mut c_char because the Go signatures can't be immutable, the Go functions -//! should not modify the strings. - -use crate::PlonkBn254Proof; -use cfg_if::cfg_if; -use std::ffi::{c_char, CString}; - -#[allow(warnings, clippy::all)] -mod bind { - #[cfg(feature = "plonk")] - include!(concat!(env!("OUT_DIR"), "/bindings.rs")); -} -use bind::*; - -pub fn prove_plonk_bn254(data_dir: &str, witness_path: &str) -> PlonkBn254Proof { - cfg_if! { - if #[cfg(feature = "plonk")] { - let data_dir = CString::new(data_dir).expect("CString::new failed"); - let witness_path = CString::new(witness_path).expect("CString::new failed"); - - let proof = unsafe { - let proof = bind::ProvePlonkBn254( - data_dir.as_ptr() as *mut c_char, - witness_path.as_ptr() as *mut c_char, - ); - // Safety: The pointer is returned from the go code and is guaranteed to be valid. - *proof - }; - - proof.into_rust() - } else { - panic!("plonk feature not enabled"); - } - } -} - -pub fn build_plonk_bn254(data_dir: &str) { - cfg_if! { - if #[cfg(feature = "plonk")] { - let data_dir = CString::new(data_dir).expect("CString::new failed"); - - unsafe { - bind::BuildPlonkBn254(data_dir.as_ptr() as *mut c_char); - } - } else { - panic!("plonk feature not enabled"); - - } - } -} - -pub fn verify_plonk_bn254( - data_dir: &str, - proof: &str, - vkey_hash: &str, - committed_values_digest: &str, -) -> Result<(), String> { - cfg_if! { - if #[cfg(feature = "plonk")] { - let data_dir = CString::new(data_dir).expect("CString::new failed"); - let proof = CString::new(proof).expect("CString::new failed"); - let vkey_hash = CString::new(vkey_hash).expect("CString::new failed"); - let committed_values_digest = - CString::new(committed_values_digest).expect("CString::new failed"); - - let err_ptr = unsafe { - bind::VerifyPlonkBn254( - data_dir.as_ptr() as *mut c_char, - proof.as_ptr() as *mut c_char, - vkey_hash.as_ptr() as *mut c_char, - committed_values_digest.as_ptr() as *mut c_char, - ) - }; - if err_ptr.is_null() { - Ok(()) - } else { - // Safety: The error message is returned from the go code and is guaranteed to be valid. - let err = unsafe { CString::from_raw(err_ptr) }; - Err(err.into_string().unwrap()) - } - } else { - panic!("plonk feature not enabled"); - } - } -} - -pub fn test_plonk_bn254(witness_json: &str, constraints_json: &str) { - cfg_if! { - if #[cfg(feature = "plonk")] { - unsafe { - let witness_json = CString::new(witness_json).expect("CString::new failed"); - let build_dir = CString::new(constraints_json).expect("CString::new failed"); - let err_ptr = bind::TestPlonkBn254( - witness_json.as_ptr() as *mut c_char, - build_dir.as_ptr() as *mut c_char, - ); - if !err_ptr.is_null() { - // Safety: The error message is returned from the go code and is guaranteed to be valid. - let err = CString::from_raw(err_ptr); - panic!("TestPlonkBn254 failed: {}", err.into_string().unwrap()); - } - } - } else { - panic!("plonk feature not enabled"); - } - } -} - -/// Converts a C string into a Rust String. -/// -/// # Safety -/// This function frees the string memory, so the caller must ensure that the pointer is not used -/// after this function is called. 
-unsafe fn c_char_ptr_to_string(input: *mut c_char) -> String { - unsafe { - CString::from_raw(input) // Converts a pointer that C uses into a CString - .into_string() - .expect("CString::into_string failed") - } -} - -#[cfg(feature = "plonk")] -impl C_PlonkBn254Proof { - /// Converts a C PlonkBn254Proof into a Rust PlonkBn254Proof, freeing the C strings. - fn into_rust(self) -> PlonkBn254Proof { - // Safety: The raw pointers are not used anymore after converted into Rust strings. - unsafe { - PlonkBn254Proof { - public_inputs: [ - c_char_ptr_to_string(self.PublicInputs[0]), - c_char_ptr_to_string(self.PublicInputs[1]), - ], - encoded_proof: c_char_ptr_to_string(self.EncodedProof), - raw_proof: c_char_ptr_to_string(self.RawProof), - } - } - } -} diff --git a/recursion/gnark-ffi/src/ffi/docker.rs b/recursion/gnark-ffi/src/ffi/docker.rs new file mode 100644 index 0000000000..46b21c8423 --- /dev/null +++ b/recursion/gnark-ffi/src/ffi/docker.rs @@ -0,0 +1,114 @@ +use sp1_core::SP1_CIRCUIT_VERSION; + +use crate::PlonkBn254Proof; +use std::io::Write; +use std::process::Command; + +/// Checks that docker is installed and running. +fn check_docker() -> bool { + let output = Command::new("docker").arg("info").output(); + output.is_ok() && output.unwrap().status.success() +} + +/// Panics if docker is not installed and running. +fn assert_docker() { + if !check_docker() { + panic!("Failed to run `docker info`. Please ensure that docker is installed and running."); + } +} + +fn get_docker_image() -> String { + std::env::var("SP1_GNARK_IMAGE") + .unwrap_or_else(|_| format!("ghcr.io/succinctlabs/sp1-gnark:{}", SP1_CIRCUIT_VERSION)) +} + +/// Calls `docker run` with the given arguments and bind mounts. +fn call_docker(args: &[&str], mounts: &[(&str, &str)]) -> anyhow::Result<()> { + log::info!("Running {} in docker", args[0]); + let mut cmd = Command::new("docker"); + cmd.args(["run", "--rm"]); + for (src, dest) in mounts { + cmd.arg("-v").arg(format!("{}:{}", src, dest)); + } + cmd.arg(get_docker_image()); + cmd.args(args); + if !cmd.status()?.success() { + log::error!("Failed to run `docker run`: {:?}", cmd); + return Err(anyhow::anyhow!("docker command failed")); + } + Ok(()) +} + +pub fn prove_plonk_bn254(data_dir: &str, witness_path: &str) -> PlonkBn254Proof { + let output_file = tempfile::NamedTempFile::new().unwrap(); + let mounts = [ + (data_dir, "/circuit"), + (witness_path, "/witness"), + (output_file.path().to_str().unwrap(), "/output"), + ]; + assert_docker(); + call_docker(&["prove-plonk", "/circuit", "/witness", "/output"], &mounts) + .expect("failed to prove with docker"); + bincode::deserialize_from(&output_file).expect("failed to deserialize result") +} + +pub fn build_plonk_bn254(data_dir: &str) { + let circuit_dir = if data_dir.ends_with("dev") { + "/circuit_dev" + } else { + "/circuit" + }; + let mounts = [(data_dir, circuit_dir)]; + assert_docker(); + call_docker(&["build-plonk", circuit_dir], &mounts).expect("failed to build with docker"); +} + +pub fn verify_plonk_bn254( + data_dir: &str, + proof: &str, + vkey_hash: &str, + committed_values_digest: &str, +) -> Result<(), String> { + // Write proof string to a file since it can be large. 
+ let mut proof_file = tempfile::NamedTempFile::new().unwrap(); + proof_file.write_all(proof.as_bytes()).unwrap(); + let output_file = tempfile::NamedTempFile::new().unwrap(); + let mounts = [ + (data_dir, "/circuit"), + (proof_file.path().to_str().unwrap(), "/proof"), + (output_file.path().to_str().unwrap(), "/output"), + ]; + assert_docker(); + call_docker( + &[ + "verify-plonk", + "/circuit", + "/proof", + vkey_hash, + committed_values_digest, + "/output", + ], + &mounts, + ) + .expect("failed to verify with docker"); + let result = std::fs::read_to_string(output_file.path()).unwrap(); + if result == "OK" { + Ok(()) + } else { + Err(result) + } +} + +pub fn test_plonk_bn254(witness_json: &str, constraints_json: &str) { + let mounts = [ + (constraints_json, "/constraints"), + (witness_json, "/witness"), + ]; + assert_docker(); + call_docker(&["test-plonk", "/constraints", "/witness"], &mounts) + .expect("failed to test with docker"); +} + +pub fn test_babybear_poseidon2() { + unimplemented!() +} diff --git a/recursion/gnark-ffi/src/ffi/mod.rs b/recursion/gnark-ffi/src/ffi/mod.rs new file mode 100644 index 0000000000..42358f90e8 --- /dev/null +++ b/recursion/gnark-ffi/src/ffi/mod.rs @@ -0,0 +1,9 @@ +cfg_if::cfg_if! { + if #[cfg(feature = "native")] { + mod native; + pub use native::*; + } else { + mod docker; + pub use docker::*; + } +} diff --git a/recursion/gnark-ffi/src/ffi/native.rs b/recursion/gnark-ffi/src/ffi/native.rs new file mode 100644 index 0000000000..9d88a6b5b6 --- /dev/null +++ b/recursion/gnark-ffi/src/ffi/native.rs @@ -0,0 +1,144 @@ +#![allow(unused)] + +//! FFI bindings for the Go code. The functions exported in this module are safe to call from Rust. +//! All C strings and other C memory should be freed in Rust, including C Strings returned by Go. +//! Although we cast to *mut c_char because the Go signatures can't be immutable, the Go functions +//! should not modify the strings. + +use crate::PlonkBn254Proof; +use cfg_if::cfg_if; +use sp1_core::SP1_CIRCUIT_VERSION; +use std::ffi::{c_char, CString}; + +#[allow(warnings, clippy::all)] +mod bind { + include!(concat!(env!("OUT_DIR"), "/bindings.rs")); +} +use bind::*; + +pub fn prove_plonk_bn254(data_dir: &str, witness_path: &str) -> PlonkBn254Proof { + let data_dir = CString::new(data_dir).expect("CString::new failed"); + let witness_path = CString::new(witness_path).expect("CString::new failed"); + + let proof = unsafe { + let proof = bind::ProvePlonkBn254( + data_dir.as_ptr() as *mut c_char, + witness_path.as_ptr() as *mut c_char, + ); + // Safety: The pointer is returned from the go code and is guaranteed to be valid. 
+ *proof + }; + + proof.into_rust() +} + +pub fn build_plonk_bn254(data_dir: &str) { + let data_dir = CString::new(data_dir).expect("CString::new failed"); + + unsafe { + bind::BuildPlonkBn254(data_dir.as_ptr() as *mut c_char); + } +} + +pub fn verify_plonk_bn254( + data_dir: &str, + proof: &str, + vkey_hash: &str, + committed_values_digest: &str, +) -> Result<(), String> { + let data_dir = CString::new(data_dir).expect("CString::new failed"); + let proof = CString::new(proof).expect("CString::new failed"); + let vkey_hash = CString::new(vkey_hash).expect("CString::new failed"); + let committed_values_digest = + CString::new(committed_values_digest).expect("CString::new failed"); + + let err_ptr = unsafe { + bind::VerifyPlonkBn254( + data_dir.as_ptr() as *mut c_char, + proof.as_ptr() as *mut c_char, + vkey_hash.as_ptr() as *mut c_char, + committed_values_digest.as_ptr() as *mut c_char, + ) + }; + if err_ptr.is_null() { + Ok(()) + } else { + // Safety: The error message is returned from the go code and is guaranteed to be valid. + let err = unsafe { CString::from_raw(err_ptr) }; + Err(err.into_string().unwrap()) + } +} + +pub fn test_plonk_bn254(witness_json: &str, constraints_json: &str) { + unsafe { + let witness_json = CString::new(witness_json).expect("CString::new failed"); + let build_dir = CString::new(constraints_json).expect("CString::new failed"); + let err_ptr = bind::TestPlonkBn254( + witness_json.as_ptr() as *mut c_char, + build_dir.as_ptr() as *mut c_char, + ); + if !err_ptr.is_null() { + // Safety: The error message is returned from the go code and is guaranteed to be valid. + let err = CString::from_raw(err_ptr); + panic!("TestPlonkBn254 failed: {}", err.into_string().unwrap()); + } + } +} + +pub fn test_babybear_poseidon2() { + unsafe { + let err_ptr = bind::TestPoseidonBabyBear2(); + if !err_ptr.is_null() { + // Safety: The error message is returned from the go code and is guaranteed to be valid. + let err = CString::from_raw(err_ptr); + panic!("TestPlonkBn254 failed: {}", err.into_string().unwrap()); + } + } +} + +/// Converts a C string into a Rust String. +/// +/// # Safety +/// This function frees the string memory, so the caller must ensure that the pointer is not used +/// after this function is called. +unsafe fn c_char_ptr_to_string(input: *mut c_char) -> String { + unsafe { + CString::from_raw(input) // Converts a pointer that C uses into a CString + .into_string() + .expect("CString::into_string failed") + } +} + +impl C_PlonkBn254Proof { + /// Converts a C PlonkBn254Proof into a Rust PlonkBn254Proof, freeing the C strings. + fn into_rust(self) -> PlonkBn254Proof { + // Safety: The raw pointers are not used anymore after converted into Rust strings. 
+ unsafe { + PlonkBn254Proof { + public_inputs: [ + c_char_ptr_to_string(self.PublicInputs[0]), + c_char_ptr_to_string(self.PublicInputs[1]), + ], + encoded_proof: c_char_ptr_to_string(self.EncodedProof), + raw_proof: c_char_ptr_to_string(self.RawProof), + plonk_vkey_hash: [0; 32], + } + } + } +} + +#[cfg(test)] +mod tests { + use p3_baby_bear::BabyBear; + use p3_field::AbstractField; + use p3_symmetric::Permutation; + + #[test] + pub fn test_babybear_poseidon2() { + let perm = sp1_core::utils::inner_perm(); + let zeros = [BabyBear::zero(); 16]; + let result = perm.permute(zeros); + println!("{:?}", result); + super::test_babybear_poseidon2(); + } +} diff --git a/recursion/gnark-ffi/src/lib.rs b/recursion/gnark-ffi/src/lib.rs index 26940e863b..670a9b08fe 100644 --- a/recursion/gnark-ffi/src/lib.rs +++ b/recursion/gnark-ffi/src/lib.rs @@ -1,5 +1,7 @@ mod babybear; + pub mod ffi; + pub mod plonk_bn254; pub mod witness; diff --git a/recursion/gnark-ffi/src/plonk_bn254.rs b/recursion/gnark-ffi/src/plonk_bn254.rs index 74727c328b..8d13c09946 100644 --- a/recursion/gnark-ffi/src/plonk_bn254.rs +++ b/recursion/gnark-ffi/src/plonk_bn254.rs @@ -4,13 +4,14 @@ use std::{ path::{Path, PathBuf}, }; -use crate::{ - ffi::{build_plonk_bn254, prove_plonk_bn254, test_plonk_bn254, verify_plonk_bn254}, - witness::GnarkWitness, -}; +use crate::ffi::{build_plonk_bn254, prove_plonk_bn254, test_plonk_bn254, verify_plonk_bn254}; +use crate::witness::GnarkWitness; use num_bigint::BigUint; use serde::{Deserialize, Serialize}; +use sha2::Digest; +use sha2::Sha256; +use sp1_core::SP1_CIRCUIT_VERSION; use sp1_recursion_compiler::{ constraints::Constraint, ir::{Config, Witness}, @@ -26,6 +27,7 @@ pub struct PlonkBn254Proof { pub public_inputs: [String; 2], pub encoded_proof: String, pub raw_proof: String, + pub plonk_vkey_hash: [u8; 32], } impl PlonkBn254Prover { @@ -34,6 +36,12 @@ impl PlonkBn254Prover { Self } + pub fn get_vkey_hash(build_dir: &Path) -> [u8; 32] { + let vkey_path = build_dir.join("vk.bin"); + let vk_bin_bytes = std::fs::read(vkey_path).unwrap(); + Sha256::digest(vk_bin_bytes).into() + } + /// Executes the prover in testing mode with a circuit definition and witness. pub fn test(constraints: Vec, witness: Witness) { let serialized = serde_json::to_string(&constraints).unwrap(); @@ -81,7 +89,13 @@ impl PlonkBn254Prover { .unwrap(); let sp1_verifier_path = build_dir.join("SP1Verifier.sol"); - let sp1_verifier_str = include_str!("../assets/SP1Verifier.txt"); + let vkey_hash = Self::get_vkey_hash(&build_dir); + let sp1_verifier_str = include_str!("../assets/SP1Verifier.txt") + .replace("{SP1_CIRCUIT_VERSION}", SP1_CIRCUIT_VERSION) + .replace( + "{VKEY_HASH}", + format!("0x{}", hex::encode(vkey_hash)).as_str(), + ); let mut sp1_verifier_file = File::create(sp1_verifier_path).unwrap(); sp1_verifier_file .write_all(sp1_verifier_str.as_bytes()) @@ -95,7 +109,7 @@ impl PlonkBn254Prover { .unwrap(); } - /// Generates a PLONK proof by sending a request to the Gnark server. + /// Generates a PLONK proof given a witness. pub fn prove(&self, witness: Witness, build_dir: PathBuf) -> PlonkBn254Proof { // Write witness. 
let mut witness_file = tempfile::NamedTempFile::new().unwrap(); @@ -103,10 +117,12 @@ impl PlonkBn254Prover { let serialized = serde_json::to_string(&gnark_witness).unwrap(); witness_file.write_all(serialized.as_bytes()).unwrap(); - prove_plonk_bn254( + let mut proof = prove_plonk_bn254( build_dir.to_str().unwrap(), witness_file.path().to_str().unwrap(), - ) + ); + proof.plonk_vkey_hash = Self::get_vkey_hash(&build_dir); + proof } /// Verify a PLONK proof and verify that the supplied vkey_hash and committed_values_digest match. @@ -117,6 +133,9 @@ impl PlonkBn254Prover { committed_values_digest: &BigUint, build_dir: &Path, ) { + if proof.plonk_vkey_hash != Self::get_vkey_hash(build_dir) { + panic!("Proof vkey hash does not match circuit vkey hash, it was generated with a different circuit."); + } verify_plonk_bn254( build_dir.to_str().unwrap(), &proof.raw_proof, diff --git a/recursion/groth16/constraints.json b/recursion/groth16/constraints.json deleted file mode 100644 index 28fc1fdac3..0000000000 --- a/recursion/groth16/constraints.json +++ /dev/null @@ -1 +0,0 @@ -[{"opcode":"ImmV","args":[["var0"],["100"]]},{"opcode":"Num2BitsV","args":[["var1","var2","var3","var4","var5","var6","var7","var8","var9","var10","var11","var12","var13","var14","var15","var16","var17","var18","var19","var20","var21","var22","var23","var24","var25","var26","var27","var28","var29","var30","var31","var32"],["var0"],["32"]]},{"opcode":"ImmV","args":[["backend0"],["0"]]},{"opcode":"AssertEqV","args":[["var1"],["backend0"]]},{"opcode":"ImmV","args":[["backend1"],["0"]]},{"opcode":"AssertEqV","args":[["var2"],["backend1"]]},{"opcode":"ImmV","args":[["backend2"],["1"]]},{"opcode":"AssertEqV","args":[["var3"],["backend2"]]},{"opcode":"ImmV","args":[["backend3"],["0"]]},{"opcode":"AssertEqV","args":[["var4"],["backend3"]]},{"opcode":"ImmV","args":[["backend4"],["0"]]},{"opcode":"AssertEqV","args":[["var5"],["backend4"]]},{"opcode":"ImmV","args":[["backend5"],["1"]]},{"opcode":"AssertEqV","args":[["var6"],["backend5"]]},{"opcode":"ImmV","args":[["backend6"],["1"]]},{"opcode":"AssertEqV","args":[["var7"],["backend6"]]},{"opcode":"ImmV","args":[["backend7"],["0"]]},{"opcode":"AssertEqV","args":[["var8"],["backend7"]]},{"opcode":"ImmV","args":[["backend8"],["0"]]},{"opcode":"AssertEqV","args":[["var9"],["backend8"]]},{"opcode":"ImmV","args":[["backend9"],["0"]]},{"opcode":"AssertEqV","args":[["var10"],["backend9"]]},{"opcode":"ImmV","args":[["backend10"],["0"]]},{"opcode":"AssertEqV","args":[["var11"],["backend10"]]},{"opcode":"ImmV","args":[["backend11"],["0"]]},{"opcode":"AssertEqV","args":[["var12"],["backend11"]]},{"opcode":"ImmV","args":[["backend12"],["0"]]},{"opcode":"AssertEqV","args":[["var13"],["backend12"]]},{"opcode":"ImmV","args":[["backend13"],["0"]]},{"opcode":"AssertEqV","args":[["var14"],["backend13"]]},{"opcode":"ImmV","args":[["backend14"],["0"]]},{"opcode":"AssertEqV","args":[["var15"],["backend14"]]},{"opcode":"ImmV","args":[["backend15"],["0"]]},{"opcode":"AssertEqV","args":[["var16"],["backend15"]]},{"opcode":"ImmV","args":[["backend16"],["0"]]},{"opcode":"AssertEqV","args":[["var17"],["backend16"]]},{"opcode":"ImmV","args":[["backend17"],["0"]]},{"opcode":"AssertEqV","args":[["var18"],["backend17"]]},{"opcode":"ImmV","args":[["backend18"],["0"]]},{"opcode":"AssertEqV","args":[["var19"],["backend18"]]},{"opcode":"ImmV","args":[["backend19"],["0"]]},{"opcode":"AssertEqV","args":[["var20"],["backend19"]]},{"opcode":"ImmV","args":[["backend20"],["0"]]},{"opcode":"AssertEqV","args":[["var21"],["backend20
"]]},{"opcode":"ImmV","args":[["backend21"],["0"]]},{"opcode":"AssertEqV","args":[["var22"],["backend21"]]},{"opcode":"ImmV","args":[["backend22"],["0"]]},{"opcode":"AssertEqV","args":[["var23"],["backend22"]]},{"opcode":"ImmV","args":[["backend23"],["0"]]},{"opcode":"AssertEqV","args":[["var24"],["backend23"]]},{"opcode":"ImmV","args":[["backend24"],["0"]]},{"opcode":"AssertEqV","args":[["var25"],["backend24"]]},{"opcode":"ImmV","args":[["backend25"],["0"]]},{"opcode":"AssertEqV","args":[["var26"],["backend25"]]},{"opcode":"ImmV","args":[["backend26"],["0"]]},{"opcode":"AssertEqV","args":[["var27"],["backend26"]]},{"opcode":"ImmV","args":[["backend27"],["0"]]},{"opcode":"AssertEqV","args":[["var28"],["backend27"]]},{"opcode":"ImmV","args":[["backend28"],["0"]]},{"opcode":"AssertEqV","args":[["var29"],["backend28"]]},{"opcode":"ImmV","args":[["backend29"],["0"]]},{"opcode":"AssertEqV","args":[["var30"],["backend29"]]},{"opcode":"ImmV","args":[["backend30"],["0"]]},{"opcode":"AssertEqV","args":[["var31"],["backend30"]]},{"opcode":"ImmV","args":[["backend31"],["0"]]},{"opcode":"AssertEqV","args":[["var32"],["backend31"]]}] \ No newline at end of file diff --git a/recursion/groth16/lib/libbabybear.a b/recursion/groth16/lib/libbabybear.a deleted file mode 100644 index e047c94965..0000000000 Binary files a/recursion/groth16/lib/libbabybear.a and /dev/null differ diff --git a/recursion/groth16/main b/recursion/groth16/main deleted file mode 100755 index 126a88bb45..0000000000 Binary files a/recursion/groth16/main and /dev/null differ diff --git a/recursion/groth16/witness.json b/recursion/groth16/witness.json deleted file mode 100644 index ed4386877e..0000000000 --- a/recursion/groth16/witness.json +++ /dev/null @@ -1 +0,0 @@ -{"vars":["999"],"felts":["999"],"exts":[["999","0","0","0"]]} \ No newline at end of file diff --git a/recursion/program/Cargo.toml b/recursion/program/Cargo.toml index 47dd52380c..decfb2151e 100644 --- a/recursion/program/Cargo.toml +++ b/recursion/program/Cargo.toml @@ -20,7 +20,10 @@ p3-poseidon2 = { workspace = true } sp1-recursion-core = { path = "../core" } sp1-recursion-compiler = { path = "../compiler" } sp1-core = { path = "../../core" } -itertools = "0.12.1" +itertools = "0.13.0" serde = { version = "1.0.201", features = ["derive"] } rand = "0.8.5" tracing = "0.1.40" + +[features] +debug = ["sp1-core/debug"] \ No newline at end of file diff --git a/recursion/program/src/challenger.rs b/recursion/program/src/challenger.rs index ceea9168ac..b6943d27f7 100644 --- a/recursion/program/src/challenger.rs +++ b/recursion/program/src/challenger.rs @@ -298,6 +298,7 @@ mod tests { use sp1_recursion_compiler::ir::Felt; use sp1_recursion_compiler::ir::Usize; use sp1_recursion_compiler::ir::Var; + use sp1_recursion_core::runtime::PERMUTATION_WIDTH; use sp1_recursion_core::stark::utils::run_test_recursion; use sp1_recursion_core::stark::utils::TestConfig; diff --git a/recursion/program/src/constraints.rs b/recursion/program/src/constraints.rs index 8ea42a2d72..67eeea5681 100644 --- a/recursion/program/src/constraints.rs +++ b/recursion/program/src/constraints.rs @@ -158,6 +158,7 @@ where #[cfg(test)] mod tests { use itertools::{izip, Itertools}; + use rand::{thread_rng, Rng}; use serde::{de::DeserializeOwned, Serialize}; use sp1_core::{ io::SP1Stdin, @@ -171,7 +172,7 @@ mod tests { use sp1_recursion_core::stark::utils::{run_test_recursion, TestConfig}; use p3_challenger::{CanObserve, FieldChallenger}; - use sp1_recursion_compiler::{asm::AsmBuilder, prelude::ExtConst}; + use 
sp1_recursion_compiler::{asm::AsmBuilder, ir::Felt, prelude::ExtConst}; use p3_commit::{Pcs, PolynomialSpace}; @@ -355,4 +356,33 @@ mod tests { let program = builder.compile_program(); run_test_recursion(program, None, TestConfig::All); } + + #[test] + fn test_exp_reverse_bit_len_fast() { + type SC = BabyBearPoseidon2; + type F = ::Val; + type EF = ::Challenge; + + let mut rng = thread_rng(); + + // Initialize a builder. + let mut builder = AsmBuilder::::default(); + + // Get a random var with `NUM_BITS` bits. + let x_val: F = rng.gen(); + + // Materialize the number as a var + let x_felt: Felt<_> = builder.eval(x_val); + let x_bits = builder.num2bits_f(x_felt); + + let result = builder.exp_reverse_bits_len_fast(x_felt, &x_bits, 5); + let expected_val = builder.exp_reverse_bits_len(x_felt, &x_bits, 5); + + builder.assert_felt_eq(expected_val, result); + builder.halt(); + + let program = builder.compile_program(); + + run_test_recursion(program, None, TestConfig::All); + } } diff --git a/recursion/program/src/fri/mod.rs b/recursion/program/src/fri/mod.rs index 53affc3e3e..f274a65bd7 100644 --- a/recursion/program/src/fri/mod.rs +++ b/recursion/program/src/fri/mod.rs @@ -137,10 +137,7 @@ where let folded_eval: Ext = builder.eval(C::F::zero()); let two_adic_generator_f = config.get_two_adic_generator(builder, log_max_height); - let two_adic_gen_ext = two_adic_generator_f.to_operand().symbolic(); - let two_adic_generator_ef: Ext<_, _> = builder.eval(two_adic_gen_ext); - - let x = builder.exp_reverse_bits_len(two_adic_generator_ef, index_bits, log_max_height); + let x = builder.exp_reverse_bits_len_fast(two_adic_generator_f, index_bits, log_max_height); let log_max_height = log_max_height.materialize(builder); builder @@ -190,10 +187,10 @@ where .if_eq(index_sibling_mod_2, C::N::zero()) .then_or_else( |builder| { - builder.assign(xs_0, x * two_adic_generator_one); + builder.assign(xs_0, x * two_adic_generator_one.to_operand().symbolic()); }, |builder| { - builder.assign(xs_1, x * two_adic_generator_one); + builder.assign(xs_1, x * two_adic_generator_one.to_operand().symbolic()); }, ); diff --git a/recursion/program/src/fri/two_adic_pcs.rs b/recursion/program/src/fri/two_adic_pcs.rs index aee46cfd28..354ff66611 100644 --- a/recursion/program/src/fri/two_adic_pcs.rs +++ b/recursion/program/src/fri/two_adic_pcs.rs @@ -120,7 +120,7 @@ pub fn verify_two_adic_pcs( let two_adic_generator = config.get_two_adic_generator(builder, log_height); builder.cycle_tracker("exp_reverse_bits_len"); - let two_adic_generator_exp = builder.exp_reverse_bits_len( + let two_adic_generator_exp = builder.exp_reverse_bits_len_fast( two_adic_generator, &index_bits_shifted, log_height, diff --git a/recursion/program/src/machine/core.rs b/recursion/program/src/machine/core.rs index 1b86351109..e8a547cc09 100644 --- a/recursion/program/src/machine/core.rs +++ b/recursion/program/src/machine/core.rs @@ -160,6 +160,11 @@ where let cumulative_sum: Ext<_, _> = builder.eval(C::EF::zero().cons()); let current_pc: Felt<_> = builder.uninit(); let exit_code: Felt<_> = builder.uninit(); + + // Range check that the number of proofs is sufficiently small. + let num_shard_proofs: Var<_> = shard_proofs.len().materialize(builder); + builder.range_check_v(num_shard_proofs, 16); + // Verify proofs, validate transitions, and update accumulation variables. builder.range(0, shard_proofs.len()).for_each(|i, builder| { // Load the proof. @@ -263,6 +268,9 @@ where // Assert that exit code is the same for all proofs. 
builder.assert_felt_eq(exit_code, public_values.exit_code); + // Assert that the exit code is zero (success) for all proofs. + builder.assert_felt_eq(exit_code, C::F::zero()); + // Assert that the deferred proof digest is the same for all proofs. for (digest, current_digest) in deferred_proofs_digest .iter() diff --git a/recursion/program/src/machine/deferred.rs b/recursion/program/src/machine/deferred.rs index 2ae232ab73..a370a5b3e4 100644 --- a/recursion/program/src/machine/deferred.rs +++ b/recursion/program/src/machine/deferred.rs @@ -187,7 +187,8 @@ where let element = builder.get(&proof.public_values, j); challenger.observe(builder, element); } - // verify the proof. + + // Verify the proof. StarkVerifier::::verify_shard( builder, &compress_vk, diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index aec5d9453d..c9c8e4000e 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -3,6 +3,8 @@ use p3_commit::TwoAdicMultiplicativeCoset; use p3_field::AbstractField; use p3_field::TwoAdicField; use sp1_core::air::MachineAir; +use sp1_core::air::PublicValues; +use sp1_core::air::Word; use sp1_core::stark::Com; use sp1_core::stark::GenericVerifierConstraintFolder; use sp1_core::stark::ShardProof; @@ -12,7 +14,6 @@ use sp1_core::stark::StarkMachine; use sp1_core::stark::StarkVerifyingKey; use sp1_recursion_compiler::ir::Array; use sp1_recursion_compiler::ir::Ext; -use sp1_recursion_compiler::ir::ExtConst; use sp1_recursion_compiler::ir::SymbolicExt; use sp1_recursion_compiler::ir::SymbolicVar; use sp1_recursion_compiler::ir::Var; @@ -94,48 +95,6 @@ impl<'a, SC: StarkGenericConfig, A: MachineAir> VerifyingKeyHint<'a, SC } } -impl StarkRecursiveVerifier for StarkMachine -where - C::F: TwoAdicField, - SC: StarkGenericConfig< - Val = C::F, - Challenge = C::EF, - Domain = TwoAdicMultiplicativeCoset, - >, - A: MachineAir + for<'a> Air>, - C::F: TwoAdicField, - C::EF: TwoAdicField, - Com: Into<[SC::Val; DIGEST_SIZE]>, -{ - fn verify_shard( - &self, - builder: &mut Builder, - vk: &VerifyingKeyVariable, - pcs: &TwoAdicFriPcsVariable, - challenger: &mut DuplexChallengerVariable, - proof: &ShardProofVariable, - is_complete: impl Into::N>>, - ) { - // Verify the shard proof. - StarkVerifier::::verify_shard(builder, vk, pcs, self, challenger, proof); - - // Verify that the cumulative sum of the chip is zero if the shard is complete. - let cumulative_sum: Ext<_, _> = builder.uninit(); - builder - .range(0, proof.opened_values.chips.len()) - .for_each(|i, builder| { - let values = builder.get(&proof.opened_values.chips, i); - builder.assign(cumulative_sum, cumulative_sum + values.cumulative_sum); - }); - - builder - .if_eq(is_complete.into(), C::N::one()) - .then(|builder| { - builder.assert_ext_eq(cumulative_sum, C::EF::zero().cons()); - }); - } -} - pub type RecursiveVerifierConstraintFolder<'a, C> = GenericVerifierConstraintFolder< 'a, ::F, @@ -175,6 +134,14 @@ where .. } = proof; + // Extract public values. 
+ let mut pv_elements = Vec::new(); + for i in 0..machine.num_pv_elts() { + let element = builder.get(&proof.public_values, i); + pv_elements.push(element); + } + let public_values = PublicValues::>, Felt<_>>::from_vec(pv_elements); + let ShardCommitmentVariable { main_commit, permutation_commit, @@ -344,18 +311,43 @@ where builder.cycle_tracker("stage-d-verify-pcs"); builder.cycle_tracker("stage-e-verify-constraints"); + + let shard_bits = builder.num2bits_f(public_values.shard); + let shard = builder.bits2num_v(&shard_bits); for (i, chip) in machine.chips().iter().enumerate() { tracing::debug!("verifying constraints for chip: {}", chip.name()); let index = builder.get(&proof.sorted_idxs, i); - if chip.preprocessed_width() > 0 { + if chip.name() == "CPU" { builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); } - if chip.name() == "CPU" { + if chip.preprocessed_width() > 0 { builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); } + if chip.name() == "MemoryInit" { + builder.if_eq(shard, C::N::one()).then_or_else( + |builder| { + builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); + }, + |builder| { + builder.assert_var_eq(index, C::N::from_canonical_usize(EMPTY)); + }, + ); + } + + if chip.name() == "MemoryFinalize" { + builder.if_eq(shard, C::N::one()).then_or_else( + |builder| { + builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); + }, + |builder| { + builder.assert_var_eq(index, C::N::from_canonical_usize(EMPTY)); + }, + ); + } + builder .if_ne(index, C::N::from_canonical_usize(EMPTY)) .then(|builder| { diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index 6f0e06ff40..991d454113 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -20,7 +20,7 @@ sp1-prover = { path = "../prover" } sp1-core = { path = "../core" } futures = "0.3.30" bincode = "1.3.3" -tokio = { version = "1.37.0", features = ["full"] } +tokio = { version = "1.38.0", features = ["full"] } p3-matrix = { workspace = true } p3-commit = { workspace = true } p3-field = { workspace = true } @@ -29,21 +29,22 @@ tracing = "0.1.40" hex = "0.4.3" log = "0.4.21" axum = "=0.7.5" -alloy-sol-types = { version = "0.7.0", optional = true } +alloy-sol-types = { version = "0.7.6", optional = true } sha2 = "0.10.8" dirs = "5.0.1" tempfile = "3.10.1" num-bigint = "0.4.5" cfg-if = "1.0" ethers = { version = "2", default-features = false } -strum_macros = "0.26.2" +strum_macros = "0.26.4" strum = "0.26.2" +thiserror = "1.0.61" [features] default = ["network"] neon = ["sp1-core/neon"] -plonk = ["sp1-prover/plonk"] +native-gnark = ["sp1-prover/native-gnark"] # TODO: Once alloy has a 1.* release, we can likely remove this feature flag, as there will be less # dependency resolution issues. network = ["dep:alloy-sol-types"] diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index a48ef5e248..55fe50d0d7 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -21,6 +21,7 @@ pub mod utils { } use cfg_if::cfg_if; +pub use provers::SP1VerificationError; use std::{env, fmt::Debug, fs::File, path::Path}; use anyhow::{Ok, Result}; @@ -31,6 +32,7 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; use sp1_core::{ runtime::ExecutionReport, stark::{MachineVerificationError, ShardProof}, + SP1_CIRCUIT_VERSION, }; pub use sp1_prover::{ CoreSC, HashableKey, InnerSC, OuterSC, PlonkBn254Proof, SP1Prover, SP1ProvingKey, @@ -51,6 +53,7 @@ pub struct SP1ProofWithPublicValues
<P>
{ pub proof: P, pub stdin: SP1Stdin, pub public_values: SP1PublicValues, + pub sp1_version: String, } /// A [SP1ProofWithPublicValues] generated with [ProverClient::prove]. @@ -102,7 +105,7 @@ impl ProverClient { panic!("network feature is not enabled") } } - }, + } _ => panic!( "invalid value for SP1_PROVER enviroment variable: expected 'local', 'mock', or 'network'" ), @@ -169,6 +172,13 @@ impl ProverClient { } } + /// Gets the current version of the SP1 zkVM. + /// + /// Note: This is not the same as the version of the SP1 SDK. + pub fn version(&self) -> String { + SP1_CIRCUIT_VERSION.to_string() + } + /// Executes the given program on the given input (without generating a proof). /// /// Returns the public values and execution report of the program after it has been executed. @@ -331,7 +341,7 @@ impl ProverClient { &self, proof: &SP1Proof, vkey: &SP1VerifyingKey, - ) -> Result<(), SP1ProofVerificationError> { + ) -> Result<(), SP1VerificationError> { self.prover.verify(proof, vkey) } @@ -363,7 +373,7 @@ impl ProverClient { &self, proof: &SP1CompressedProof, vkey: &SP1VerifyingKey, - ) -> Result<()> { + ) -> Result<(), SP1VerificationError> { self.prover.verify_compressed(proof, vkey) } @@ -393,7 +403,11 @@ impl ProverClient { /// // Verify the proof. /// client.verify_plonk(&proof, &vk).unwrap(); /// ``` - pub fn verify_plonk(&self, proof: &SP1PlonkBn254Proof, vkey: &SP1VerifyingKey) -> Result<()> { + pub fn verify_plonk( + &self, + proof: &SP1PlonkBn254Proof, + vkey: &SP1VerifyingKey, + ) -> Result<(), SP1VerificationError> { self.prover.verify_plonk(proof, vkey) } } @@ -419,8 +433,13 @@ impl SP1ProofWithPublicValues
<PlonkBn254Proof> { pub fn bytes(&self) ->
String { - format!("0x{}", self.proof.encoded_proof.clone()) + format!( + "0x{}{}", + hex::encode(&self.proof.plonk_vkey_hash[..4]), + &self.proof.encoded_proof + ) } } diff --git a/sdk/src/network/prover.rs b/sdk/src/network/prover.rs index c4870fcacf..ba54cb6d9e 100644 --- a/sdk/src/network/prover.rs +++ b/sdk/src/network/prover.rs @@ -3,16 +3,15 @@ use std::{env, time::Duration}; use crate::proto::network::ProofMode; use crate::{ network::client::{NetworkClient, DEFAULT_PROVER_NETWORK_RPC}, - proto::network::{ProofStatus, TransactionStatus}, + proto::network::ProofStatus, Prover, }; use crate::{SP1CompressedProof, SP1PlonkBn254Proof, SP1Proof, SP1ProvingKey, SP1VerifyingKey}; -use anyhow::{Context, Result}; +use anyhow::Result; use serde::de::DeserializeOwned; -use sp1_prover::install::PLONK_BN254_ARTIFACTS_COMMIT; use sp1_prover::utils::block_on; -use sp1_prover::{SP1Prover, SP1Stdin}; -use tokio::{runtime, time::sleep}; +use sp1_prover::{SP1Prover, SP1Stdin, SP1_CIRCUIT_VERSION}; +use tokio::time::sleep; use crate::provers::{LocalProver, ProverType}; @@ -23,23 +22,32 @@ pub struct NetworkProver { } impl NetworkProver { - /// Creates a new [NetworkProver]. + /// Creates a new [NetworkProver] with the private key set in `SP1_PRIVATE_KEY`. pub fn new() -> Self { let private_key = env::var("SP1_PRIVATE_KEY") .unwrap_or_else(|_| panic!("SP1_PRIVATE_KEY must be set for remote proving")); + Self::new_from_key(&private_key) + } + + /// Creates a new [NetworkProver] with the given private key. + pub fn new_from_key(private_key: &str) -> Self { + let version = SP1_CIRCUIT_VERSION; + log::info!("Client circuit version: {}", version); + let local_prover = LocalProver::new(); Self { - client: NetworkClient::new(&private_key), + client: NetworkClient::new(private_key), local_prover, } } - pub async fn prove_async( + /// Requests a proof from the prover network, returning the proof ID. + pub async fn request_proof( &self, elf: &[u8], stdin: SP1Stdin, mode: ProofMode, - ) -> Result
<P>
{ + ) -> Result { let client = &self.client; let skip_simulation = env::var("SKIP_SIMULATION") @@ -56,22 +64,25 @@ impl NetworkProver { log::info!("Skipping simulation"); } - let version = PLONK_BN254_ARTIFACTS_COMMIT; - log::info!("Client version {}", version); - + let version = SP1_CIRCUIT_VERSION; let proof_id = client.create_proof(elf, &stdin, mode, version).await?; log::info!("Created {}", proof_id); if NetworkClient::rpc_url() == DEFAULT_PROVER_NETWORK_RPC { log::info!( "View in explorer: https://explorer.succinct.xyz/{}", - proof_id.split('_').last().unwrap_or(&proof_id) + proof_id ); } + Ok(proof_id) + } + /// Waits for a proof to be generated and returns the proof. + pub async fn wait_proof(&self, proof_id: &str) -> Result
<P>
{ + let client = &self.client; let mut is_claimed = false; loop { - let (status, maybe_proof) = client.get_proof_status::
<P>
(&proof_id).await?; + let (status, maybe_proof) = client.get_proof_status::
<P>
(proof_id).await?; match status.status() { ProofStatus::ProofFulfilled => { @@ -95,68 +106,10 @@ impl NetworkProver { } } - #[allow(dead_code)] - /// Remotely relay a proof to a set of chains with their callback contracts. - pub fn remote_relay( - &self, - proof_id: &str, - chain_ids: Vec, - callbacks: Vec<[u8; 20]>, - callback_datas: Vec>, - ) -> Result> { - let rt = runtime::Runtime::new()?; - rt.block_on(async { - let client = &self.client; - - let verifier = NetworkClient::get_sp1_verifier_address(); - - let mut tx_details = Vec::new(); - for ((i, callback), callback_data) in - callbacks.iter().enumerate().zip(callback_datas.iter()) - { - if let Some(&chain_id) = chain_ids.get(i) { - let tx_id = client - .relay_proof(proof_id, chain_id, verifier, *callback, callback_data) - .await - .with_context(|| format!("Failed to relay proof to chain {}", chain_id))?; - tx_details.push((tx_id.clone(), chain_id)); - } - } - - let mut tx_ids = Vec::new(); - for (tx_id, chain_id) in tx_details.iter() { - loop { - let (status_res, maybe_tx_hash, maybe_simulation_url) = - client.get_relay_status(tx_id).await?; - - match status_res.status() { - TransactionStatus::TransactionFinalized => { - println!( - "Relaying to chain {} succeeded with tx hash: {:?}", - chain_id, - maybe_tx_hash.as_deref().unwrap_or("None") - ); - tx_ids.push(tx_id.clone()); - break; - } - TransactionStatus::TransactionFailed - | TransactionStatus::TransactionTimedout => { - return Err(anyhow::anyhow!( - "Relaying to chain {} failed with tx hash: {:?}, simulation url: {:?}", - chain_id, - maybe_tx_hash.as_deref().unwrap_or("None"), - maybe_simulation_url.as_deref().unwrap_or("None") - )); - } - _ => { - sleep(Duration::from_secs(5)).await; - } - } - } - } - - Ok(tx_ids) - }) + /// Requests a proof from the prover network and waits for it to be generated. + pub async fn prove(&self, elf: &[u8], stdin: SP1Stdin) -> Result
<P>
{ + let proof_id = self.request_proof(elf, stdin, P::PROOF_MODE).await?; + self.wait_proof(&proof_id).await } } @@ -174,15 +127,15 @@ impl Prover for NetworkProver { } fn prove(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result { - block_on(self.prove_async(&pk.elf, stdin, ProofMode::Core)) + block_on(self.prove(&pk.elf, stdin)) } fn prove_compressed(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result { - block_on(self.prove_async(&pk.elf, stdin, ProofMode::Compressed)) + block_on(self.prove(&pk.elf, stdin)) } fn prove_plonk(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result { - block_on(self.prove_async(&pk.elf, stdin, ProofMode::Plonk)) + block_on(self.prove(&pk.elf, stdin)) } } @@ -191,3 +144,20 @@ impl Default for NetworkProver { Self::new() } } + +/// A deserializable proof struct that has an associated ProofMode. +pub trait ProofType: DeserializeOwned { + const PROOF_MODE: ProofMode; +} + +impl ProofType for SP1Proof { + const PROOF_MODE: ProofMode = ProofMode::Core; +} + +impl ProofType for SP1CompressedProof { + const PROOF_MODE: ProofMode = ProofMode::Compressed; +} + +impl ProofType for SP1PlonkBn254Proof { + const PROOF_MODE: ProofMode = ProofMode::Plonk; +} diff --git a/sdk/src/provers/local.rs b/sdk/src/provers/local.rs index 5eadf79a34..d3c60cdac1 100644 --- a/sdk/src/provers/local.rs +++ b/sdk/src/provers/local.rs @@ -1,5 +1,4 @@ use anyhow::Result; -use cfg_if::cfg_if; use sp1_prover::{SP1Prover, SP1Stdin}; use crate::{ @@ -41,6 +40,7 @@ impl Prover for LocalProver { proof: proof.proof.0, stdin: proof.stdin, public_values: proof.public_values, + sp1_version: self.version().to_string(), }) } @@ -53,39 +53,35 @@ impl Prover for LocalProver { proof: reduce_proof.proof, stdin, public_values, + sp1_version: self.version().to_string(), }) } - #[allow(unused)] fn prove_plonk(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result { - cfg_if! 
{ - if #[cfg(feature = "plonk")] { - - let proof = self.prover.prove_core(pk, &stdin)?; - let deferred_proofs = stdin.proofs.iter().map(|p| p.0.clone()).collect(); - let public_values = proof.public_values.clone(); - let reduce_proof = self.prover.compress(&pk.vk, proof, deferred_proofs)?; - let compress_proof = self.prover.shrink(reduce_proof)?; - let outer_proof = self.prover.wrap_bn254(compress_proof)?; + let proof = self.prover.prove_core(pk, &stdin)?; + let deferred_proofs = stdin.proofs.iter().map(|p| p.0.clone()).collect(); + let public_values = proof.public_values.clone(); + let reduce_proof = self.prover.compress(&pk.vk, proof, deferred_proofs)?; + let compress_proof = self.prover.shrink(reduce_proof)?; + let outer_proof = self.prover.wrap_bn254(compress_proof)?; - let plonk_bn254_aritfacts = if sp1_prover::build::sp1_dev_mode() { - sp1_prover::build::try_build_plonk_bn254_artifacts_dev( - &self.prover.wrap_vk, - &outer_proof.proof, - ) - } else { - sp1_prover::build::try_install_plonk_bn254_artifacts() - }; - let proof = self.prover.wrap_plonk_bn254(outer_proof, &plonk_bn254_aritfacts); - Ok(SP1ProofWithPublicValues { - proof, - stdin, - public_values, - }) - } else { - panic!("plonk feature not enabled") - } - } + let plonk_bn254_aritfacts = if sp1_prover::build::sp1_dev_mode() { + sp1_prover::build::try_build_plonk_bn254_artifacts_dev( + &self.prover.wrap_vk, + &outer_proof.proof, + ) + } else { + sp1_prover::build::try_install_plonk_bn254_artifacts() + }; + let proof = self + .prover + .wrap_plonk_bn254(outer_proof, &plonk_bn254_aritfacts); + Ok(SP1ProofWithPublicValues { + proof, + stdin, + public_values, + sp1_version: self.version().to_string(), + }) } } diff --git a/sdk/src/provers/mock.rs b/sdk/src/provers/mock.rs index f0d75b3101..8c5859ea82 100644 --- a/sdk/src/provers/mock.rs +++ b/sdk/src/provers/mock.rs @@ -1,7 +1,7 @@ #![allow(unused_variables)] use crate::{ - Prover, SP1CompressedProof, SP1PlonkBn254Proof, SP1Proof, SP1ProofVerificationError, - SP1ProofWithPublicValues, SP1ProvingKey, SP1VerifyingKey, + Prover, SP1CompressedProof, SP1PlonkBn254Proof, SP1Proof, SP1ProofWithPublicValues, + SP1ProvingKey, SP1VerificationError, SP1VerifyingKey, }; use anyhow::Result; use p3_field::PrimeField; @@ -43,6 +43,7 @@ impl Prover for MockProver { proof: vec![], stdin, public_values, + sp1_version: self.version().to_string(), }) } @@ -64,9 +65,11 @@ impl Prover for MockProver { ], encoded_proof: "".to_string(), raw_proof: "".to_string(), + plonk_vkey_hash: [0; 32], }, stdin, public_values, + sp1_version: self.version().to_string(), }) } @@ -74,7 +77,7 @@ impl Prover for MockProver { &self, _proof: &SP1Proof, _vkey: &SP1VerifyingKey, - ) -> Result<(), SP1ProofVerificationError> { + ) -> Result<(), SP1VerificationError> { Ok(()) } @@ -82,12 +85,17 @@ impl Prover for MockProver { &self, _proof: &SP1CompressedProof, _vkey: &SP1VerifyingKey, - ) -> Result<()> { + ) -> Result<(), SP1VerificationError> { Ok(()) } - fn verify_plonk(&self, proof: &SP1PlonkBn254Proof, vkey: &SP1VerifyingKey) -> Result<()> { - verify_plonk_bn254_public_inputs(vkey, &proof.public_values, &proof.proof.public_inputs)?; + fn verify_plonk( + &self, + proof: &SP1PlonkBn254Proof, + vkey: &SP1VerifyingKey, + ) -> Result<(), SP1VerificationError> { + verify_plonk_bn254_public_inputs(vkey, &proof.public_values, &proof.proof.public_inputs) + .map_err(SP1VerificationError::Plonk)?; Ok(()) } } diff --git a/sdk/src/provers/mod.rs b/sdk/src/provers/mod.rs index 4c731c43d4..f682601a56 100644 --- a/sdk/src/provers/mod.rs 
+++ b/sdk/src/provers/mod.rs @@ -6,12 +6,15 @@ use anyhow::Result; pub use local::LocalProver; pub use mock::MockProver; use sp1_core::stark::MachineVerificationError; +use sp1_core::SP1_CIRCUIT_VERSION; use sp1_prover::CoreSC; +use sp1_prover::InnerSC; use sp1_prover::SP1CoreProofData; use sp1_prover::SP1Prover; use sp1_prover::SP1ReduceProof; use sp1_prover::{SP1ProvingKey, SP1Stdin, SP1VerifyingKey}; use strum_macros::EnumString; +use thiserror::Error; /// The type of prover. #[derive(Debug, PartialEq, EnumString)] @@ -21,12 +24,28 @@ pub enum ProverType { Network, } +#[derive(Error, Debug)] +pub enum SP1VerificationError { + #[error("Version mismatch")] + VersionMismatch(String), + #[error("Core machine verification error: {0}")] + Core(MachineVerificationError), + #[error("Recursion verification error: {0}")] + Recursion(MachineVerificationError), + #[error("Plonk verification error: {0}")] + Plonk(anyhow::Error), +} + /// An implementation of [crate::ProverClient]. pub trait Prover: Send + Sync { fn id(&self) -> ProverType; fn sp1_prover(&self) -> &SP1Prover; + fn version(&self) -> &str { + SP1_CIRCUIT_VERSION + } + fn setup(&self, elf: &[u8]) -> (SP1ProvingKey, SP1VerifyingKey); /// Prove the execution of a RISCV ELF with the given inputs. @@ -39,17 +58,28 @@ pub trait Prover: Send + Sync { fn prove_plonk(&self, pk: &SP1ProvingKey, stdin: SP1Stdin) -> Result; /// Verify that an SP1 proof is valid given its vkey and metadata. - fn verify( - &self, - proof: &SP1Proof, - vkey: &SP1VerifyingKey, - ) -> Result<(), MachineVerificationError> { + fn verify(&self, proof: &SP1Proof, vkey: &SP1VerifyingKey) -> Result<(), SP1VerificationError> { + if proof.sp1_version != self.version() { + return Err(SP1VerificationError::VersionMismatch( + proof.sp1_version.clone(), + )); + } self.sp1_prover() .verify(&SP1CoreProofData(proof.proof.clone()), vkey) + .map_err(SP1VerificationError::Core) } /// Verify that a compressed SP1 proof is valid given its vkey and metadata. - fn verify_compressed(&self, proof: &SP1CompressedProof, vkey: &SP1VerifyingKey) -> Result<()> { + fn verify_compressed( + &self, + proof: &SP1CompressedProof, + vkey: &SP1VerifyingKey, + ) -> Result<(), SP1VerificationError> { + if proof.sp1_version != self.version() { + return Err(SP1VerificationError::VersionMismatch( + proof.sp1_version.clone(), + )); + } self.sp1_prover() .verify_compressed( &SP1ReduceProof { @@ -57,12 +87,21 @@ pub trait Prover: Send + Sync { }, vkey, ) - .map_err(|e| e.into()) + .map_err(SP1VerificationError::Recursion) } /// Verify that a SP1 PLONK proof is valid. Verify that the public inputs of the PlonkBn254 proof match /// the hash of the VK and the committed public values of the SP1ProofWithPublicValues. 
- fn verify_plonk(&self, proof: &SP1PlonkBn254Proof, vkey: &SP1VerifyingKey) -> Result<()> { + fn verify_plonk( + &self, + proof: &SP1PlonkBn254Proof, + vkey: &SP1VerifyingKey, + ) -> Result<(), SP1VerificationError> { + if proof.sp1_version != self.version() { + return Err(SP1VerificationError::VersionMismatch( + proof.sp1_version.clone(), + )); + } let sp1_prover = self.sp1_prover(); let plonk_bn254_aritfacts = if sp1_prover::build::sp1_dev_mode() { @@ -70,12 +109,14 @@ pub trait Prover: Send + Sync { } else { sp1_prover::build::try_install_plonk_bn254_artifacts() }; - sp1_prover.verify_plonk_bn254( - &proof.proof, - vkey, - &proof.public_values, - &plonk_bn254_aritfacts, - )?; + sp1_prover + .verify_plonk_bn254( + &proof.proof, + vkey, + &proof.public_values, + &plonk_bn254_aritfacts, + ) + .map_err(SP1VerificationError::Plonk)?; Ok(()) } diff --git a/tests/blake3-compress/Cargo.toml b/tests/blake3-compress/Cargo.toml deleted file mode 100644 index e5987407cc..0000000000 --- a/tests/blake3-compress/Cargo.toml +++ /dev/null @@ -1,8 +0,0 @@ -[workspace] -[package] -version = "0.1.0" -name = "blake3-compress-test" -edition = "2021" - -[dependencies] -sp1-zkvm = { path = "../../zkvm/entrypoint" } diff --git a/tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf b/tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf deleted file mode 100755 index 4e0fee0235..0000000000 Binary files a/tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf and /dev/null differ diff --git a/tests/blake3-compress/src/main.rs b/tests/blake3-compress/src/main.rs deleted file mode 100644 index 6bbee4916f..0000000000 --- a/tests/blake3-compress/src/main.rs +++ /dev/null @@ -1,42 +0,0 @@ -#![no_main] -sp1_zkvm::entrypoint!(main); - -extern "C" { - fn syscall_blake3_compress_inner(p: *mut u32, q: *const u32); -} - -pub fn main() { - // The input message and state are simply 0, 1, ..., 95 followed by some fixed constants. - for _i in 0..10 { - let input_message: [u8; 64] = [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, - 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, - 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, - ]; - - let mut input_state: [u8; 64] = [ - 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, - 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 103, 230, 9, 106, 133, 174, 103, 187, 114, 243, - 110, 60, 58, 245, 79, 165, 96, 0, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 97, 0, 0, 0, - ]; - - unsafe { - syscall_blake3_compress_inner( - input_state.as_mut_ptr() as *mut u32, - input_message.as_ptr() as *const u32, - ); - } - - // The expected output state is the result of compress_inner. 
- let output_state: [u8; 64] = [ - 239, 181, 94, 129, 58, 124, 80, 104, 126, 210, 5, 157, 255, 58, 238, 89, 252, 106, 170, - 12, 233, 56, 58, 31, 215, 16, 105, 97, 11, 229, 238, 73, 6, 79, 155, 180, 197, 73, 116, - 0, 127, 22, 16, 39, 116, 174, 85, 5, 61, 94, 87, 6, 236, 10, 36, 238, 119, 171, 207, - 171, 189, 216, 43, 250, - ]; - - assert_eq!(input_state, output_state); - } - - println!("done"); -} diff --git a/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf index fce1bf9ffe..6e2c7e6866 100755 Binary files a/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-add/src/main.rs b/tests/bls12381-add/src/main.rs index 874e9f066e..681cf39afe 100644 --- a/tests/bls12381-add/src/main.rs +++ b/tests/bls12381-add/src/main.rs @@ -6,44 +6,48 @@ extern "C" { } pub fn main() { - // generator. - // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 - // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 - let mut a: [u8; 96] = [ - 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, 23, - 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, 149, 38, - 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, 138, 136, - 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, 149, 224, 245, - 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, 179, 8, - ]; + for _ in 0..4 { + // generator. + // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 + // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 + let mut a: [u8; 96] = [ + 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, + 23, 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, + 149, 38, 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, + 138, 136, 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, + 149, 224, 245, 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, + 179, 8, + ]; - // 2 * generator. - // 838589206289216005799424730305866328161735431124665289961769162861615689790485775997575391185127590486775437397838 - // 3450209970729243429733164009999191867485184320918914219895632678707687208996709678363578245114137957452475385814312 - let b: [u8; 96] = [ - 78, 15, 191, 41, 85, 140, 154, 195, 66, 124, 28, 143, 187, 117, 143, 226, 42, 166, 88, 195, - 10, 45, 144, 67, 37, 1, 40, 145, 48, 219, 33, 151, 12, 69, 169, 80, 235, 200, 8, 136, 70, - 103, 77, 144, 234, 203, 114, 5, 40, 157, 116, 121, 25, 136, 134, 186, 27, 189, 22, 205, - 212, 217, 86, 76, 106, 215, 95, 29, 2, 185, 59, 247, 97, 228, 112, 134, 203, 62, 186, 34, - 56, 142, 157, 119, 115, 166, 253, 34, 163, 115, 198, 171, 140, 157, 106, 22, - ]; + // 2 * generator. 
+ // 838589206289216005799424730305866328161735431124665289961769162861615689790485775997575391185127590486775437397838 + // 3450209970729243429733164009999191867485184320918914219895632678707687208996709678363578245114137957452475385814312 + let b: [u8; 96] = [ + 78, 15, 191, 41, 85, 140, 154, 195, 66, 124, 28, 143, 187, 117, 143, 226, 42, 166, 88, + 195, 10, 45, 144, 67, 37, 1, 40, 145, 48, 219, 33, 151, 12, 69, 169, 80, 235, 200, 8, + 136, 70, 103, 77, 144, 234, 203, 114, 5, 40, 157, 116, 121, 25, 136, 134, 186, 27, 189, + 22, 205, 212, 217, 86, 76, 106, 215, 95, 29, 2, 185, 59, 247, 97, 228, 112, 134, 203, + 62, 186, 34, 56, 142, 157, 119, 115, 166, 253, 34, 163, 115, 198, 171, 140, 157, 106, + 22, + ]; - unsafe { - syscall_bls12381_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_bls12381_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } - // 3 * generator. - // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 - // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 - let c: [u8; 96] = [ - 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, 51, - 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, 101, - 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, 74, 99, - 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, 12, 154, 166, - 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, - ]; + // 3 * generator. + // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 + // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 + let c: [u8; 96] = [ + 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, + 51, 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, + 101, 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, + 74, 99, 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, + 12, 154, 166, 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, + ]; - assert_eq!(a, c); + assert_eq!(a, c); + } println!("done"); } diff --git a/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf index 818954dc49..3a8f2e1872 100755 Binary files a/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-decompress/src/main.rs b/tests/bls12381-decompress/src/main.rs index 3e93a099f2..a359fa9f7c 100644 --- a/tests/bls12381-decompress/src/main.rs +++ b/tests/bls12381-decompress/src/main.rs @@ -7,19 +7,22 @@ extern "C" { pub fn main() { let compressed_key: [u8; 48] = sp1_zkvm::io::read_vec().try_into().unwrap(); - let mut decompressed_key: [u8; 96] = [0u8; 96]; - decompressed_key[..48].copy_from_slice(&compressed_key); + for _ in 0..4 { + let mut decompressed_key: [u8; 96] = [0u8; 96]; - println!("before: {:?}", decompressed_key); + decompressed_key[..48].copy_from_slice(&compressed_key); - let is_odd = (decompressed_key[0] & 0b_0010_0000) >> 5 == 0; - decompressed_key[0] &= 0b_0001_1111; + println!("before: {:?}", decompressed_key); - unsafe { - 
syscall_bls12381_decompress(&mut decompressed_key, is_odd); - } - println!("after: {:?}", decompressed_key); + let is_odd = (decompressed_key[0] & 0b_0010_0000) >> 5 == 0; + decompressed_key[0] &= 0b_0001_1111; + + unsafe { + syscall_bls12381_decompress(&mut decompressed_key, is_odd); + } - sp1_zkvm::io::commit_slice(&decompressed_key); + println!("after: {:?}", decompressed_key); + sp1_zkvm::io::commit_slice(&decompressed_key); + } } diff --git a/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf index 5c4706b8fa..50470172a8 100755 Binary files a/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf index 313a6226f3..d1fe6cdf6f 100755 Binary files a/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-mul/src/main.rs b/tests/bls12381-mul/src/main.rs index 9f90906e92..89169660a3 100644 --- a/tests/bls12381-mul/src/main.rs +++ b/tests/bls12381-mul/src/main.rs @@ -6,39 +6,42 @@ use sp1_zkvm::precompiles::utils::AffinePoint; #[sp1_derive::cycle_tracker] pub fn main() { - // generator. - // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 - // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 - let a: [u8; 96] = [ - 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, 23, - 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, 149, 38, - 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, 138, 136, - 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, 149, 224, 245, - 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, 179, 8, - ]; - - let mut a_point = AffinePoint::::from_le_bytes(&a); - - // scalar. - // 3 - let scalar: [u32; 12] = [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - - println!("cycle-tracker-start: bn254_mul"); - a_point.mul_assign(&scalar); - println!("cycle-tracker-end: bn254_mul"); - - // 3 * generator. - // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 - // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 - let c: [u8; 96] = [ - 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, 51, - 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, 101, - 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, 74, 99, - 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, 12, 154, 166, - 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, - ]; - - assert_eq!(a_point.to_le_bytes(), c); + for _ in 0..4 { + // generator. 
+ // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 + // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 + let a: [u8; 96] = [ + 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, + 23, 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, + 149, 38, 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, + 138, 136, 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, + 149, 224, 245, 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, + 179, 8, + ]; + + let mut a_point = AffinePoint::::from_le_bytes(&a); + + // scalar. + // 3 + let scalar: [u32; 12] = [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + println!("cycle-tracker-start: bn254_mul"); + a_point.mul_assign(&scalar); + println!("cycle-tracker-end: bn254_mul"); + + // 3 * generator. + // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 + // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 + let c: [u8; 96] = [ + 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, + 51, 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, + 101, 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, + 74, 99, 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, + 12, 154, 166, 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, + ]; + + assert_eq!(a_point.to_le_bytes(), c); + } println!("done"); } diff --git a/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf b/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf index a45b52cd9d..a55b917d17 100755 Binary files a/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf and b/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bn254-add/src/main.rs b/tests/bn254-add/src/main.rs index 7e164663da..406681d656 100644 --- a/tests/bn254-add/src/main.rs +++ b/tests/bn254-add/src/main.rs @@ -6,40 +6,42 @@ extern "C" { } pub fn main() { - // generator. - // 1 - // 2 - let mut a: [u8; 64] = [ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, - ]; + for _ in 0..4 { + // generator. + // 1 + // 2 + let mut a: [u8; 64] = [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]; - // 2 * generator. - // 1368015179489954701390400359078579693043519447331113978918064868415326638035 - // 9918110051302171585080402603319702774565515993150576347155970296011118125764 - let b: [u8; 64] = [ - 211, 207, 135, 109, 193, 8, 194, 211, 168, 28, 135, 22, 169, 22, 120, 217, 133, 21, 24, - 104, 91, 4, 133, 155, 2, 26, 19, 46, 231, 68, 6, 3, 196, 162, 24, 90, 122, 191, 62, 255, - 199, 143, 83, 227, 73, 164, 166, 104, 10, 156, 174, 178, 150, 95, 132, 231, 146, 124, 10, - 14, 140, 115, 237, 21, - ]; + // 2 * generator. 
+ // 1368015179489954701390400359078579693043519447331113978918064868415326638035 + // 9918110051302171585080402603319702774565515993150576347155970296011118125764 + let b: [u8; 64] = [ + 211, 207, 135, 109, 193, 8, 194, 211, 168, 28, 135, 22, 169, 22, 120, 217, 133, 21, 24, + 104, 91, 4, 133, 155, 2, 26, 19, 46, 231, 68, 6, 3, 196, 162, 24, 90, 122, 191, 62, + 255, 199, 143, 83, 227, 73, 164, 166, 104, 10, 156, 174, 178, 150, 95, 132, 231, 146, + 124, 10, 14, 140, 115, 237, 21, + ]; - unsafe { - syscall_bn254_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_bn254_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } - // 3 * generator. - // 3353031288059533942658390886683067124040920775575537747144343083137631628272 - // 19321533766552368860946552437480515441416830039777911637913418824951667761761 - let c: [u8; 64] = [ - 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, - 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, 241, - 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, 148, 72, - 224, 190, 153, 183, 42, - ]; + // 3 * generator. + // 3353031288059533942658390886683067124040920775575537747144343083137631628272 + // 19321533766552368860946552437480515441416830039777911637913418824951667761761 + let c: [u8; 64] = [ + 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, + 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, + 241, 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, + 148, 72, 224, 190, 153, 183, 42, + ]; - assert_eq!(a, c); + assert_eq!(a, c); + } println!("done"); } diff --git a/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf b/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf index 2c7bcb6231..b571be7344 100755 Binary files a/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf and b/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf b/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf index a414416de7..dd1506ddc7 100755 Binary files a/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bn254-mul/src/main.rs b/tests/bn254-mul/src/main.rs index 841de5e4d0..3086c3806f 100644 --- a/tests/bn254-mul/src/main.rs +++ b/tests/bn254-mul/src/main.rs @@ -6,36 +6,38 @@ use sp1_zkvm::precompiles::utils::AffinePoint; #[sp1_derive::cycle_tracker] pub fn main() { - // generator. - // 1 - // 2 - let a: [u8; 64] = [ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, - ]; - - let mut a_point = AffinePoint::::from_le_bytes(&a); - - // scalar. - // 3 - let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; - - println!("cycle-tracker-start: bn254_mul"); - a_point.mul_assign(&scalar); - println!("cycle-tracker-end: bn254_mul"); - - // 3 * generator. 
- // 3353031288059533942658390886683067124040920775575537747144343083137631628272 - // 19321533766552368860946552437480515441416830039777911637913418824951667761761 - let c: [u8; 64] = [ - 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, - 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, 241, - 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, 148, 72, - 224, 190, 153, 183, 42, - ]; - - assert_eq!(a_point.to_le_bytes(), c); + for _ in 0..4 { + // generator. + // 1 + // 2 + let a: [u8; 64] = [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]; + + let mut a_point = AffinePoint::::from_le_bytes(&a); + + // scalar. + // 3 + let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; + + println!("cycle-tracker-start: bn254_mul"); + a_point.mul_assign(&scalar); + println!("cycle-tracker-end: bn254_mul"); + + // 3 * generator. + // 3353031288059533942658390886683067124040920775575537747144343083137631628272 + // 19321533766552368860946552437480515441416830039777911637913418824951667761761 + let c: [u8; 64] = [ + 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, + 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, + 241, 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, + 148, 72, 224, 190, 153, 183, 42, + ]; + + assert_eq!(a_point.to_le_bytes(), c); + } println!("done"); } diff --git a/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf b/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf index ed3121d5d8..6e2531ad0b 100755 Binary files a/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf and b/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf b/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf index d75d1642e9..912b0a3e2e 100755 Binary files a/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf and b/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ed-add/elf/riscv32im-succinct-zkvm-elf b/tests/ed-add/elf/riscv32im-succinct-zkvm-elf index 1f79b12f49..5916c8a800 100755 Binary files a/tests/ed-add/elf/riscv32im-succinct-zkvm-elf and b/tests/ed-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ed-add/src/main.rs b/tests/ed-add/src/main.rs index 057aea4823..3deafa0c76 100644 --- a/tests/ed-add/src/main.rs +++ b/tests/ed-add/src/main.rs @@ -6,37 +6,40 @@ extern "C" { } pub fn main() { - // 90393249858788985237231628593243673548167146579814268721945474994541877372611 - // 33321104029277118100578831462130550309254424135206412570121538923759338004303 - let mut a: [u8; 64] = [ - 195, 166, 157, 207, 218, 220, 175, 197, 111, 177, 123, 23, 73, 72, 114, 103, 28, 246, 66, - 207, 66, 146, 187, 234, 136, 238, 133, 145, 47, 196, 216, 199, 79, 31, 224, 30, 179, 122, - 51, 84, 116, 12, 4, 189, 198, 198, 190, 22, 71, 201, 143, 249, 92, 56, 147, 133, 92, 187, - 130, 33, 152, 19, 171, 73, - ]; + for _ in 0..4 { + // 90393249858788985237231628593243673548167146579814268721945474994541877372611 + // 33321104029277118100578831462130550309254424135206412570121538923759338004303 + let mut a: [u8; 64] = [ + 195, 166, 157, 207, 218, 220, 175, 197, 111, 177, 123, 23, 73, 72, 114, 103, 28, 246, + 66, 207, 66, 146, 187, 234, 136, 238, 133, 145, 47, 196, 216, 199, 79, 31, 224, 30, 
+ 179, 122, 51, 84, 116, 12, 4, 189, 198, 198, 190, 22, 71, 201, 143, 249, 92, 56, 147, + 133, 92, 187, 130, 33, 152, 19, 171, 73, + ]; - // 61717728572175158701898635111983295176935961585742968051419350619945173564869 - // 28137966556353620208933066709998005335145594788896528644015312259959272398451 - let b: [u8; 64] = [ - 197, 189, 200, 77, 201, 212, 57, 105, 191, 133, 123, 170, 167, 50, 114, 38, 37, 102, 188, - 29, 215, 227, 157, 142, 252, 31, 129, 67, 24, 255, 114, 136, 115, 94, 94, 55, 43, 200, 117, - 224, 139, 251, 238, 45, 80, 154, 70, 213, 219, 78, 201, 108, 73, 203, 72, 45, 167, 131, - 199, 47, 82, 134, 53, 62, - ]; + // 61717728572175158701898635111983295176935961585742968051419350619945173564869 + // 28137966556353620208933066709998005335145594788896528644015312259959272398451 + let b: [u8; 64] = [ + 197, 189, 200, 77, 201, 212, 57, 105, 191, 133, 123, 170, 167, 50, 114, 38, 37, 102, + 188, 29, 215, 227, 157, 142, 252, 31, 129, 67, 24, 255, 114, 136, 115, 94, 94, 55, 43, + 200, 117, 224, 139, 251, 238, 45, 80, 154, 70, 213, 219, 78, 201, 108, 73, 203, 72, 45, + 167, 131, 199, 47, 82, 134, 53, 62, + ]; - unsafe { - syscall_ed_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_ed_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } + + // 36213413123116753589144482590359479011148956763279542162278577842046663495729 + // 17093345531692682197799066694073110060588941459686871373458223451938707761683 + let c: [u8; 64] = [ + 49, 144, 129, 197, 86, 163, 62, 48, 222, 208, 213, 200, 219, 90, 163, 54, 211, 248, + 178, 224, 238, 167, 235, 219, 251, 247, 189, 239, 194, 16, 16, 80, 19, 106, 20, 198, + 72, 56, 103, 111, 68, 201, 29, 107, 75, 208, 193, 232, 181, 186, 175, 22, 213, 187, + 253, 125, 44, 80, 222, 209, 159, 125, 202, 37, + ]; - // 36213413123116753589144482590359479011148956763279542162278577842046663495729 - // 17093345531692682197799066694073110060588941459686871373458223451938707761683 - let c: [u8; 64] = [ - 49, 144, 129, 197, 86, 163, 62, 48, 222, 208, 213, 200, 219, 90, 163, 54, 211, 248, 178, - 224, 238, 167, 235, 219, 251, 247, 189, 239, 194, 16, 16, 80, 19, 106, 20, 198, 72, 56, - 103, 111, 68, 201, 29, 107, 75, 208, 193, 232, 181, 186, 175, 22, 213, 187, 253, 125, 44, - 80, 222, 209, 159, 125, 202, 37, - ]; + assert_eq!(a, c); + } - assert_eq!(a, c); println!("done"); } diff --git a/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf b/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf index 10bbf5e06e..233f1ab1cb 100755 Binary files a/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf and b/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ed-decompress/src/main.rs b/tests/ed-decompress/src/main.rs index 32f4eef659..0b6929dde4 100644 --- a/tests/ed-decompress/src/main.rs +++ b/tests/ed-decompress/src/main.rs @@ -8,26 +8,28 @@ extern "C" { } pub fn main() { - let pub_bytes = hex!("ec172b93ad5e563bf4932c70e1245034c35467ef2efd4d64ebf819683467e2bf"); + for _ in 0..4 { + let pub_bytes = hex!("ec172b93ad5e563bf4932c70e1245034c35467ef2efd4d64ebf819683467e2bf"); - let mut decompressed = [0_u8; 64]; - decompressed[32..].copy_from_slice(&pub_bytes); + let mut decompressed = [0_u8; 64]; + decompressed[32..].copy_from_slice(&pub_bytes); - println!("before: {:?}", decompressed); + println!("before: {:?}", decompressed); - unsafe { - syscall_ed_decompress(decompressed.as_mut_ptr()); - } + unsafe { + syscall_ed_decompress(decompressed.as_mut_ptr()); + } - let expected: [u8; 64] = [ - 47, 252, 114, 91, 153, 234, 
110, 201, 201, 153, 152, 14, 68, 231, 90, 221, 137, 110, 250, - 67, 10, 64, 37, 70, 163, 101, 111, 223, 185, 1, 180, 88, 236, 23, 43, 147, 173, 94, 86, 59, - 244, 147, 44, 112, 225, 36, 80, 52, 195, 84, 103, 239, 46, 253, 77, 100, 235, 248, 25, 104, - 52, 103, 226, 63, - ]; + let expected: [u8; 64] = [ + 47, 252, 114, 91, 153, 234, 110, 201, 201, 153, 152, 14, 68, 231, 90, 221, 137, 110, + 250, 67, 10, 64, 37, 70, 163, 101, 111, 223, 185, 1, 180, 88, 236, 23, 43, 147, 173, + 94, 86, 59, 244, 147, 44, 112, 225, 36, 80, 52, 195, 84, 103, 239, 46, 253, 77, 100, + 235, 248, 25, 104, 52, 103, 226, 63, + ]; - assert_eq!(decompressed, expected); + assert_eq!(decompressed, expected); + println!("after: {:?}", decompressed); + } - println!("after: {:?}", decompressed); println!("done"); } diff --git a/tests/ed25519/elf/riscv32im-succinct-zkvm-elf b/tests/ed25519/elf/riscv32im-succinct-zkvm-elf index 88c83e3c0a..5f149617c0 100755 Binary files a/tests/ed25519/elf/riscv32im-succinct-zkvm-elf and b/tests/ed25519/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/fibonacci/Cargo.toml b/tests/fibonacci/Cargo.toml index 31f8878487..231775d5be 100644 --- a/tests/fibonacci/Cargo.toml +++ b/tests/fibonacci/Cargo.toml @@ -1,7 +1,7 @@ [workspace] [package] version = "0.1.0" -name = "fibonacci-program" +name = "fibonacci-program-tests" edition = "2021" [dependencies] diff --git a/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf b/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf index 1c59449d83..7a61102c17 100755 Binary files a/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf and b/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/hint-io/elf/riscv32im-succinct-zkvm-elf b/tests/hint-io/elf/riscv32im-succinct-zkvm-elf index 69fc40b116..ac7a2fc293 100755 Binary files a/tests/hint-io/elf/riscv32im-succinct-zkvm-elf and b/tests/hint-io/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf b/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf index a843a07799..15dee99151 100755 Binary files a/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf and b/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/keccak256/elf/riscv32im-succinct-zkvm-elf b/tests/keccak256/elf/riscv32im-succinct-zkvm-elf index 48e4965b34..311da32c16 100755 Binary files a/tests/keccak256/elf/riscv32im-succinct-zkvm-elf and b/tests/keccak256/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/panic/elf/riscv32im-succinct-zkvm-elf b/tests/panic/elf/riscv32im-succinct-zkvm-elf index e68a4a4dc9..8debb2189a 100755 Binary files a/tests/panic/elf/riscv32im-succinct-zkvm-elf and b/tests/panic/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/blake3-compress/Cargo.lock b/tests/rand/Cargo.lock similarity index 88% rename from tests/blake3-compress/Cargo.lock rename to tests/rand/Cargo.lock index b45f827d8d..6811bed258 100644 --- a/tests/blake3-compress/Cargo.lock +++ b/tests/rand/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "25bdb32cbbdce2b519a9cd7df3a678443100e265d5e25ca763b7572a5104f5f3" [[package]] name = "arrayvec" @@ -16,9 +16,9 @@ checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" [[package]] name = "autocfg" -version = "1.2.0" +version = "1.3.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" [[package]] name = "base16ct" @@ -53,13 +53,6 @@ dependencies = [ "wyz", ] -[[package]] -name = "blake3-compress-test" -version = "0.1.0" -dependencies = [ - "sp1-zkvm", -] - [[package]] name = "block-buffer" version = "0.10.4" @@ -120,9 +113,9 @@ dependencies = [ [[package]] name = "der" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" +checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" dependencies = [ "const-oid", "zeroize", @@ -243,9 +236,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.3" +version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" [[package]] name = "hex" @@ -297,11 +290,17 @@ dependencies = [ "signature", ] +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + [[package]] name = "libc" -version = "0.2.155" +version = "0.2.154" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" [[package]] name = "libm" @@ -396,9 +395,9 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "parity-scale-codec" -version = "3.6.9" +version = "3.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "881331e34fa842a2fb61cc2db9643a8fedc615e47cfcc52597d1af0db9a7e8fe" +checksum = "306800abfa29c7f16596b5970a588435e3d5b3149683d00c12b699cc19f895ee" dependencies = [ "arrayvec", "byte-slice-cast", @@ -408,11 +407,11 @@ dependencies = [ [[package]] name = "parity-scale-codec-derive" -version = "3.6.9" +version = "3.6.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be30eaf4b0a9fba5336683b38de57bb86d179a35862ba6bfcf57625d006bde5b" +checksum = "d830939c76d294956402033aee57a6da7b438f2294eb94864c37b0569053a42c" dependencies = [ - "proc-macro-crate 2.0.2", + "proc-macro-crate", "proc-macro2", "quote", "syn 1.0.109", @@ -436,38 +435,27 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro-crate" -version = "1.3.1" +version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" +checksum = "6d37c51ca738a55da99dc0c4a34860fd675453b8b36209178c2249bb13651284" dependencies = [ - "once_cell", - "toml_edit 0.19.15", -] - -[[package]] -name = "proc-macro-crate" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b00f26d3400549137f92511a46ac1cd8ce37cb5598a96d382381458b992a5d24" -dependencies = [ - "toml_datetime", - "toml_edit 0.20.2", + "toml_edit", ] [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.35" +version = "1.0.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" dependencies = [ "proc-macro2", ] @@ -489,6 +477,15 @@ dependencies = [ "rand_core", ] +[[package]] +name = "rand-test" +version = "0.1.0" +dependencies = [ + "rand", + "sp1-derive", + "sp1-zkvm", +] + [[package]] name = "rand_chacha" version = "0.3.1" @@ -520,9 +517,9 @@ dependencies = [ [[package]] name = "scale-info" -version = "2.11.2" +version = "2.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c453e59a955f81fb62ee5d596b450383d699f152d350e9d23a0db2adb78e4c0" +checksum = "eca070c12893629e2cc820a9761bedf6ce1dcddc9852984d1dc734b8bd9bd024" dependencies = [ "cfg-if", "derive_more", @@ -532,11 +529,11 @@ dependencies = [ [[package]] name = "scale-info-derive" -version = "2.11.2" +version = "2.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18cf6c6447f813ef19eb450e985bcce6705f9ce7660db221b59093d15c79c4b7" +checksum = "2d35494501194174bda522a32605929eefc9ecf7e0a326c26db1fdd85881eb62" dependencies = [ - "proc-macro-crate 1.3.1", + "proc-macro-crate", "proc-macro2", "quote", "syn 1.0.109", @@ -558,22 +555,22 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.203" +version = "1.0.201" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" +checksum = "780f1cebed1629e4753a1a38a3c72d30b97ec044f0aef68cb26650a3c5cf363c" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.203" +version = "1.0.201" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" +checksum = "c5e405930b9796f1c00bee880d03fc7e0bb4b9a11afc776885ffe84320da2865" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.63", ] [[package]] @@ -607,6 +604,15 @@ dependencies = [ "scale-info", ] +[[package]] +name = "sp1-derive" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "sp1-precompiles" version = "0.1.0" @@ -631,6 +637,7 @@ dependencies = [ "cfg-if", "getrandom", "k256", + "lazy_static", "libm", "once_cell", "rand", @@ -668,9 +675,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "bf5be731623ca1a1fb7d8be6f261a3be6d3e2337b8a1f97be944d020c8fcb704" dependencies = [ "proc-macro2", "quote", @@ -685,26 +692,15 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "toml_datetime" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" - -[[package]] -name = "toml_edit" -version = "0.19.15" +version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" -dependencies = [ - 
"indexmap", - "toml_datetime", - "winnow", -] +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" [[package]] name = "toml_edit" -version = "0.20.2" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "396e4d48bbb2b7554c944bde63101b5ae446cff6ec4a24227428f15eb72ef338" +checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" dependencies = [ "indexmap", "toml_datetime", diff --git a/tests/uint256-div/Cargo.toml b/tests/rand/Cargo.toml similarity index 70% rename from tests/uint256-div/Cargo.toml rename to tests/rand/Cargo.toml index 495c8bc9a6..2ada8cf376 100644 --- a/tests/uint256-div/Cargo.toml +++ b/tests/rand/Cargo.toml @@ -1,11 +1,10 @@ [workspace] [package] -name = "uint256-div-test" +name = "rand-test" version = "0.1.0" edition = "2021" [dependencies] -rand = "0.8" -num = { version = "0.4.1" } sp1-zkvm = { path = "../../zkvm/entrypoint" } sp1-derive = { path = "../../derive" } +rand = "0.8.5" diff --git a/tests/rand/elf/riscv32im-succinct-zkvm-elf b/tests/rand/elf/riscv32im-succinct-zkvm-elf new file mode 100755 index 0000000000..f5ddaebdc2 Binary files /dev/null and b/tests/rand/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/rand/src/main.rs b/tests/rand/src/main.rs new file mode 100644 index 0000000000..b9a5cac6da --- /dev/null +++ b/tests/rand/src/main.rs @@ -0,0 +1,12 @@ +#![no_main] +sp1_zkvm::entrypoint!(main); + +use rand::Rng; + +pub fn main() { + let mut rng = rand::thread_rng(); + for _ in 0..16 { + let num = rng.gen::(); + println!("{num}"); + } +} diff --git a/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf index 339003c773..bf7a3db101 100755 Binary files a/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-add/src/main.rs b/tests/secp256k1-add/src/main.rs index c45601bcc8..9640e4e8c7 100644 --- a/tests/secp256k1-add/src/main.rs +++ b/tests/secp256k1-add/src/main.rs @@ -6,41 +6,43 @@ extern "C" { } pub fn main() { - // generator. - // 55066263022277343669578718895168534326250603453777594175500187360389116729240 - // 32670510020758816978083085130507043184471273380659243275938904335757337482424 - let mut a: [u8; 64] = [ - 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, 206, - 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, 208, 71, - 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, 101, 196, 163, - 38, 119, 218, 58, 72, - ]; + for _ in 0..4 { + // generator. + // 55066263022277343669578718895168534326250603453777594175500187360389116729240 + // 32670510020758816978083085130507043184471273380659243275938904335757337482424 + let mut a: [u8; 64] = [ + 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, + 206, 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, + 208, 71, 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, + 101, 196, 163, 38, 119, 218, 58, 72, + ]; - // 2 * generator. 
- // 89565891926547004231252920425935692360644145829622209833684329913297188986597 - // 12158399299693830322967808612713398636155367887041628176798871954788371653930 - let b: [u8; 64] = [ - 229, 158, 112, 92, 185, 9, 172, 171, 167, 60, 239, 140, 75, 142, 119, 92, 216, 124, 192, - 149, 110, 64, 69, 48, 109, 125, 237, 65, 148, 127, 4, 198, 42, 229, 207, 80, 169, 49, 100, - 35, 225, 208, 102, 50, 101, 50, 246, 247, 238, 234, 108, 70, 25, 132, 197, 163, 57, 195, - 61, 166, 254, 104, 225, 26, - ]; + // 2 * generator. + // 89565891926547004231252920425935692360644145829622209833684329913297188986597 + // 12158399299693830322967808612713398636155367887041628176798871954788371653930 + let b: [u8; 64] = [ + 229, 158, 112, 92, 185, 9, 172, 171, 167, 60, 239, 140, 75, 142, 119, 92, 216, 124, + 192, 149, 110, 64, 69, 48, 109, 125, 237, 65, 148, 127, 4, 198, 42, 229, 207, 80, 169, + 49, 100, 35, 225, 208, 102, 50, 101, 50, 246, 247, 238, 234, 108, 70, 25, 132, 197, + 163, 57, 195, 61, 166, 254, 104, 225, 26, + ]; - unsafe { - syscall_secp256k1_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_secp256k1_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } - // 3 * generator. - // 112711660439710606056748659173929673102114977341539408544630613555209775888121 - // 25583027980570883691656905877401976406448868254816295069919888960541586679410 - let c: [u8; 64] = [ - 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, 248, - 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, 185, 108, - 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, 232, 45, 99, 15, - 123, 143, 56, - ]; + // 3 * generator. + // 112711660439710606056748659173929673102114977341539408544630613555209775888121 + // 25583027980570883691656905877401976406448868254816295069919888960541586679410 + let c: [u8; 64] = [ + 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, + 248, 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, + 185, 108, 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, + 232, 45, 99, 15, 123, 143, 56, + ]; - assert_eq!(a, c); + assert_eq!(a, c); + } println!("done"); } diff --git a/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf index e06da48d78..2fae11204b 100755 Binary files a/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-decompress/src/main.rs b/tests/secp256k1-decompress/src/main.rs index 6dc18a25ce..a603986e00 100644 --- a/tests/secp256k1-decompress/src/main.rs +++ b/tests/secp256k1-decompress/src/main.rs @@ -8,20 +8,22 @@ extern "C" { pub fn main() { let compressed_key: [u8; 33] = sp1_zkvm::io::read_vec().try_into().unwrap(); - let mut decompressed_key: [u8; 64] = [0; 64]; - decompressed_key[..32].copy_from_slice(&compressed_key[1..]); - let is_odd = match compressed_key[0] { - 2 => false, - 3 => true, - _ => panic!("Invalid compressed key"), - }; - unsafe { - syscall_secp256k1_decompress(&mut decompressed_key, is_odd); - } + for _ in 0..4 { + let mut decompressed_key: [u8; 64] = [0; 64]; + decompressed_key[..32].copy_from_slice(&compressed_key[1..]); + let is_odd = match compressed_key[0] { + 2 => false, + 3 => true, + _ => panic!("Invalid compressed key"), + }; + unsafe { + syscall_secp256k1_decompress(&mut 
decompressed_key, is_odd); + } - let mut result: [u8; 65] = [0; 65]; - result[0] = 4; - result[1..].copy_from_slice(&decompressed_key); + let mut result: [u8; 65] = [0; 65]; + result[0] = 4; + result[1..].copy_from_slice(&decompressed_key); - sp1_zkvm::io::commit_slice(&result); + sp1_zkvm::io::commit_slice(&result); + } } diff --git a/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf index 6ad007626d..79a156fcab 100755 Binary files a/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf index ec0db8bd02..d3e17ead66 100755 Binary files a/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-mul/src/main.rs b/tests/secp256k1-mul/src/main.rs index 731a81b381..a2fb6a3dd3 100644 --- a/tests/secp256k1-mul/src/main.rs +++ b/tests/secp256k1-mul/src/main.rs @@ -6,37 +6,39 @@ use sp1_zkvm::precompiles::utils::AffinePoint; #[sp1_derive::cycle_tracker] pub fn main() { - // generator. - // 55066263022277343669578718895168534326250603453777594175500187360389116729240 - // 32670510020758816978083085130507043184471273380659243275938904335757337482424 - let a: [u8; 64] = [ - 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, 206, - 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, 208, 71, - 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, 101, 196, 163, - 38, 119, 218, 58, 72, - ]; - - let mut a_point = AffinePoint::::from_le_bytes(&a); - - // scalar. - // 3 - let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; - - println!("cycle-tracker-start: secp256k1_mul"); - a_point.mul_assign(&scalar); - println!("cycle-tracker-end: secp256k1_mul"); - - // 3 * generator. - // 112711660439710606056748659173929673102114977341539408544630613555209775888121 - // 25583027980570883691656905877401976406448868254816295069919888960541586679410 - let c: [u8; 64] = [ - 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, 248, - 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, 185, 108, - 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, 232, 45, 99, 15, - 123, 143, 56, - ]; - - assert_eq!(a_point.to_le_bytes(), c); + for _ in 0..4 { + // generator. + // 55066263022277343669578718895168534326250603453777594175500187360389116729240 + // 32670510020758816978083085130507043184471273380659243275938904335757337482424 + let a: [u8; 64] = [ + 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, + 206, 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, + 208, 71, 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, + 101, 196, 163, 38, 119, 218, 58, 72, + ]; + + let mut a_point = AffinePoint::::from_le_bytes(&a); + + // scalar. + // 3 + let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; + + println!("cycle-tracker-start: secp256k1_mul"); + a_point.mul_assign(&scalar); + println!("cycle-tracker-end: secp256k1_mul"); + + // 3 * generator. 
+ // 112711660439710606056748659173929673102114977341539408544630613555209775888121 + // 25583027980570883691656905877401976406448868254816295069919888960541586679410 + let c: [u8; 64] = [ + 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, + 248, 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, + 185, 108, 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, + 232, 45, 99, 15, 123, 143, 56, + ]; + + assert_eq!(a_point.to_le_bytes(), c); + } println!("done"); } diff --git a/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf b/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf index 97126f881c..f10443e120 100755 Binary files a/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf and b/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/sha-compress/src/main.rs b/tests/sha-compress/src/main.rs index bdddab1662..3c306966c2 100644 --- a/tests/sha-compress/src/main.rs +++ b/tests/sha-compress/src/main.rs @@ -6,6 +6,10 @@ use sp1_zkvm::syscalls::syscall_sha256_compress; pub fn main() { let mut w = [1u32; 64]; let mut state = [1u32; 8]; - syscall_sha256_compress(w.as_mut_ptr(), state.as_mut_ptr()); + + for _ in 0..4 { + syscall_sha256_compress(w.as_mut_ptr(), state.as_mut_ptr()); + } + println!("{:?}", state); } diff --git a/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf b/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf index 7b8774766b..d584e1c358 100755 Binary files a/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf and b/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/sha2/elf/riscv32im-succinct-zkvm-elf b/tests/sha2/elf/riscv32im-succinct-zkvm-elf index ff4661defc..2c63e6648b 100755 Binary files a/tests/sha2/elf/riscv32im-succinct-zkvm-elf and b/tests/sha2/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf b/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf index d67e8be4d5..526fc2f836 100755 Binary files a/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf and b/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf b/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf index 83a521f1b2..a80353d4f2 100755 Binary files a/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf and b/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/uint256-arith/src/main.rs b/tests/uint256-arith/src/main.rs index 872af3bf60..72eb00dbe7 100644 --- a/tests/uint256-arith/src/main.rs +++ b/tests/uint256-arith/src/main.rs @@ -3,7 +3,6 @@ #![no_main] sp1_zkvm::entrypoint!(main); -use crypto_bigint::NonZero; use crypto_bigint::{Wrapping, U256}; use std::hint::black_box; @@ -25,38 +24,27 @@ pub fn uint256_mul(a: U256, b: U256) -> U256 { result.0 } -#[sp1_derive::cycle_tracker] -pub fn uint256_div(a: U256, b: U256) -> U256 { - Wrapping(a) - .0 - .wrapping_div_vartime(&Wrapping(NonZero::new(b).unwrap()).0) -} - pub fn main() { let a = U256::from(3u8); let b = U256::from(2u8); - println!("cycle-tracker-start: uint256_add"); - let add = uint256_add(black_box(a), black_box(b)); - assert_eq!(add, U256::from(5u8)); - println!("cycle-tracker-end: uint256_add"); - println!("{:?}", add); - - println!("cycle-tracker-start: uint256_sub"); - let sub = uint256_sub(black_box(a), black_box(b)); - assert_eq!(sub, U256::from(1u8)); - println!("cycle-tracker-end: uint256_sub"); - println!("{:?}", sub); - - 
println!("cycle-tracker-start: uint256_div"); - let div = uint256_div(black_box(a), black_box(b)); - assert_eq!(div, U256::from(1u8)); - println!("cycle-tracker-end: uint256_div"); - println!("{:?}", div); - - println!("cycle-tracker-start: uint256_mul"); - let mul = uint256_mul(black_box(a), black_box(b)); - assert_eq!(mul, U256::from(6u8)); - println!("cycle-tracker-end: uint256_mul"); - println!("{:?}", mul); + for _ in 0..4 { + println!("cycle-tracker-start: uint256_add"); + let add = uint256_add(black_box(a), black_box(b)); + assert_eq!(add, U256::from(5u8)); + println!("cycle-tracker-end: uint256_add"); + println!("{:?}", add); + + println!("cycle-tracker-start: uint256_sub"); + let sub = uint256_sub(black_box(a), black_box(b)); + assert_eq!(sub, U256::from(1u8)); + println!("cycle-tracker-end: uint256_sub"); + println!("{:?}", sub); + + println!("cycle-tracker-start: uint256_mul"); + let mul = uint256_mul(black_box(a), black_box(b)); + assert_eq!(mul, U256::from(6u8)); + println!("cycle-tracker-end: uint256_mul"); + println!("{:?}", mul); + } } diff --git a/tests/uint256-div/Cargo.lock b/tests/uint256-div/Cargo.lock deleted file mode 100644 index 136db8c0b0..0000000000 --- a/tests/uint256-div/Cargo.lock +++ /dev/null @@ -1,772 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "anyhow" -version = "1.0.86" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" - -[[package]] -name = "arrayvec" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" - -[[package]] -name = "autocfg" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" - -[[package]] -name = "base16ct" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" - -[[package]] -name = "base64ct" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" - -[[package]] -name = "bincode" -version = "1.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" -dependencies = [ - "serde", -] - -[[package]] -name = "bitvec" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" -dependencies = [ - "funty", - "radium", - "tap", - "wyz", -] - -[[package]] -name = "block-buffer" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" -dependencies = [ - "generic-array", -] - -[[package]] -name = "byte-slice-cast" -version = "1.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "const-oid" -version = "0.9.6" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" - -[[package]] -name = "cpufeatures" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" -dependencies = [ - "libc", -] - -[[package]] -name = "crypto-bigint" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" -dependencies = [ - "generic-array", - "rand_core", - "subtle", - "zeroize", -] - -[[package]] -name = "crypto-common" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" -dependencies = [ - "generic-array", - "typenum", -] - -[[package]] -name = "der" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" -dependencies = [ - "const-oid", - "zeroize", -] - -[[package]] -name = "derive_more" -version = "0.99.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "digest" -version = "0.10.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "block-buffer", - "const-oid", - "crypto-common", - "subtle", -] - -[[package]] -name = "ecdsa" -version = "0.16.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" -dependencies = [ - "der", - "digest", - "elliptic-curve", - "rfc6979", - "signature", - "spki", -] - -[[package]] -name = "elliptic-curve" -version = "0.13.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" -dependencies = [ - "base16ct", - "crypto-bigint", - "digest", - "ff", - "generic-array", - "group", - "pkcs8", - "rand_core", - "sec1", - "subtle", - "tap", - "zeroize", -] - -[[package]] -name = "equivalent" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" - -[[package]] -name = "ff" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" -dependencies = [ - "bitvec", - "rand_core", - "subtle", -] - -[[package]] -name = "funty" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" - -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", - "zeroize", -] - -[[package]] -name = "getrandom" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - 
-[[package]] -name = "group" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" -dependencies = [ - "ff", - "rand_core", - "subtle", -] - -[[package]] -name = "hashbrown" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" - -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - -[[package]] -name = "impl-trait-for-tuples" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11d7a9f6330b71fea57921c9b61c47ee6e84f72d394754eff6163ae67e7395eb" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "indexmap" -version = "2.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" -dependencies = [ - "equivalent", - "hashbrown", -] - -[[package]] -name = "k256" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "956ff9b67e26e1a6a866cb758f12c6f8746208489e3e4a4b5580802f2f0a587b" -dependencies = [ - "cfg-if", - "ecdsa", - "elliptic-curve", - "once_cell", - "sha2", - "signature", -] - -[[package]] -name = "libc" -version = "0.2.155" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" - -[[package]] -name = "libm" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" - -[[package]] -name = "memchr" -version = "2.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" - -[[package]] -name = "num" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - -[[package]] -name = "num-bigint" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" -dependencies = [ - "num-integer", - "num-traits", -] - -[[package]] -name = "num-complex" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-iter" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" -dependencies = [ - 
"autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" -dependencies = [ - "num-bigint", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "parity-scale-codec" -version = "3.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "881331e34fa842a2fb61cc2db9643a8fedc615e47cfcc52597d1af0db9a7e8fe" -dependencies = [ - "arrayvec", - "byte-slice-cast", - "impl-trait-for-tuples", - "parity-scale-codec-derive", -] - -[[package]] -name = "parity-scale-codec-derive" -version = "3.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be30eaf4b0a9fba5336683b38de57bb86d179a35862ba6bfcf57625d006bde5b" -dependencies = [ - "proc-macro-crate 2.0.2", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der", - "spki", -] - -[[package]] -name = "ppv-lite86" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - -[[package]] -name = "proc-macro-crate" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" -dependencies = [ - "once_cell", - "toml_edit 0.19.15", -] - -[[package]] -name = "proc-macro-crate" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b00f26d3400549137f92511a46ac1cd8ce37cb5598a96d382381458b992a5d24" -dependencies = [ - "toml_datetime", - "toml_edit 0.20.2", -] - -[[package]] -name = "proc-macro2" -version = "1.0.79" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "radium" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] 
-name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom", -] - -[[package]] -name = "rfc6979" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" -dependencies = [ - "hmac", - "subtle", -] - -[[package]] -name = "scale-info" -version = "2.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c453e59a955f81fb62ee5d596b450383d699f152d350e9d23a0db2adb78e4c0" -dependencies = [ - "cfg-if", - "derive_more", - "parity-scale-codec", - "scale-info-derive", -] - -[[package]] -name = "scale-info-derive" -version = "2.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18cf6c6447f813ef19eb450e985bcce6705f9ce7660db221b59093d15c79c4b7" -dependencies = [ - "proc-macro-crate 1.3.1", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "sec1" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" -dependencies = [ - "base16ct", - "der", - "generic-array", - "pkcs8", - "subtle", - "zeroize", -] - -[[package]] -name = "serde" -version = "1.0.203" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.203" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.55", -] - -[[package]] -name = "sha2" -version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "signature" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" -dependencies = [ - "digest", - "rand_core", -] - -[[package]] -name = "snowbridge-amcl" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "460a9ed63cdf03c1b9847e8a12a5f5ba19c4efd5869e4a737e05be25d7c427e5" -dependencies = [ - "parity-scale-codec", - "scale-info", -] - -[[package]] -name = "sp1-derive" -version = "0.1.0" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "sp1-precompiles" -version = "0.1.0" -dependencies = [ - "anyhow", - "bincode", - "cfg-if", - "getrandom", - "hex", - "k256", - "num", - "rand", - "serde", - "snowbridge-amcl", -] - -[[package]] -name = "sp1-zkvm" -version = "0.1.0" -dependencies = [ - "bincode", - "cfg-if", - "getrandom", - "k256", - "libm", - "once_cell", - "rand", - "serde", - "sha2", - "sp1-precompiles", -] - -[[package]] -name = "spki" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" -dependencies = [ - "base64ct", - "der", -] - -[[package]] -name = "subtle" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "syn" -version = "2.0.55" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "002a1b3dbf967edfafc32655d0f377ab0bb7b994aa1d32c8cc7e9b8bf3ebb8f0" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "tap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - -[[package]] -name = "toml_datetime" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" - -[[package]] -name = "toml_edit" -version = "0.19.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" -dependencies = [ - "indexmap", - "toml_datetime", - "winnow", -] - -[[package]] -name = "toml_edit" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "396e4d48bbb2b7554c944bde63101b5ae446cff6ec4a24227428f15eb72ef338" -dependencies = [ - "indexmap", - "toml_datetime", - "winnow", -] - -[[package]] -name = "typenum" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" - -[[package]] -name = "uint256-div-test" -version = "0.1.0" -dependencies = [ - "num", - "rand", - "sp1-derive", - "sp1-zkvm", -] - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "winnow" -version = "0.5.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" -dependencies = [ - "memchr", -] - -[[package]] -name = "wyz" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" -dependencies = [ - "tap", -] - -[[package]] -name = "zeroize" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf b/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf deleted file mode 100755 index ca4f0641f3..0000000000 Binary files a/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf and /dev/null differ diff --git a/tests/uint256-div/src/main.rs b/tests/uint256-div/src/main.rs deleted file mode 100644 index a4bf44e629..0000000000 --- a/tests/uint256-div/src/main.rs +++ 
/dev/null @@ -1,71 +0,0 @@ -#![no_main] -sp1_zkvm::entrypoint!(main); - -use num::BigUint; -use rand::Rng; - -use sp1_zkvm::precompiles::uint256_div::uint256_div; - -#[sp1_derive::cycle_tracker] -fn main() { - // Random test. - for _ in 0..100 { - // generate random dividend and divisor. - let mut rng = rand::thread_rng(); - let mut dividend: [u8; 32] = rng.gen(); - let divisor: [u8; 32] = rng.gen(); - - // Skip division by zero - if divisor == [0; 32] { - continue; - } - - // Convert byte arrays to BigUint for validation. - let dividend_big = BigUint::from_bytes_le(÷nd); - let divisor_big = BigUint::from_bytes_le(&divisor); - - let quotient_big = ÷nd_big / &divisor_big; - - // Perform division. - let quotient = uint256_div(&mut dividend, &divisor); - - let quotient_precompile_big = BigUint::from_bytes_le("ient); - - // Check if the product of quotient and divisor equals the dividend - assert_eq!( - quotient_precompile_big, quotient_big, - "Quotient should match." - ); - } - - // Hardcoded edge case: division by 1. - let mut rng = rand::thread_rng(); - let mut dividend: [u8; 32] = rng.gen(); - let mut divisor = [0; 32]; - divisor[0] = 1; - - let expected_quotient: [u8; 32] = dividend.clone(); - - let quotient = uint256_div(&mut dividend, &divisor); - assert_eq!( - quotient, expected_quotient, - "Dividing by 1 should yield the same number." - ); - - // Hardcoded edge case: when the dividend is smaller thant the divisor. - // In this case, the quotient should be zero. - let mut dividend = [0; 32]; - dividend[0] = 1; - let mut divisor = [0; 32]; - divisor[0] = 4; - - let expected_quotient: [u8; 32] = [0; 32]; - - let quotient = uint256_div(&mut dividend, &divisor); - assert_eq!( - quotient, expected_quotient, - "The quotient should be zero when the dividend is smaller than the divisor." 
- ); - - println!("All tests passed."); -} diff --git a/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf b/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf index b49313ca61..2ab53b56d0 100755 Binary files a/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf b/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf index 63f478712f..93c15811cb 100755 Binary files a/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf and b/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf differ diff --git a/zkvm/entrypoint/Cargo.toml b/zkvm/entrypoint/Cargo.toml index 1c69b11628..5cdea96798 100644 --- a/zkvm/entrypoint/Cargo.toml +++ b/zkvm/entrypoint/Cargo.toml @@ -17,6 +17,7 @@ rand = "0.8.5" serde = { version = "1.0.201", features = ["derive"] } libm = { version = "0.2.8", optional = true } sha2 = { version = "0.10.8" } +lazy_static = "1.4.0" [features] default = ["libm"] diff --git a/zkvm/entrypoint/src/lib.rs b/zkvm/entrypoint/src/lib.rs index 11a79f5ca6..51cf158a3d 100644 --- a/zkvm/entrypoint/src/lib.rs +++ b/zkvm/entrypoint/src/lib.rs @@ -90,24 +90,16 @@ mod zkvm { .option pop; la sp, {0} lw sp, 0(sp) - jal ra, __start; + call __start; "#, sym STACK_TOP ); - static GETRANDOM_WARNING_ONCE: std::sync::Once = std::sync::Once::new(); - fn zkvm_getrandom(s: &mut [u8]) -> Result<(), Error> { - use rand::Rng; - use rand::SeedableRng; - - GETRANDOM_WARNING_ONCE.call_once(|| { - println!("WARNING: Using insecure random number generator"); - }); - let mut rng = rand::rngs::StdRng::seed_from_u64(123); - for i in 0..s.len() { - s[i] = rng.gen(); + unsafe { + crate::syscalls::sys_rand(s.as_mut_ptr(), s.len()); } + Ok(()) } diff --git a/zkvm/precompiles/src/bigint_mulmod.rs b/zkvm/entrypoint/src/syscalls/bigint.rs similarity index 95% rename from zkvm/precompiles/src/bigint_mulmod.rs rename to zkvm/entrypoint/src/syscalls/bigint.rs index c61746fd0b..ee05211be2 100644 --- a/zkvm/precompiles/src/bigint_mulmod.rs +++ b/zkvm/entrypoint/src/syscalls/bigint.rs @@ -1,6 +1,6 @@ -use crate::syscall_uint256_mulmod; +use sp1_precompiles::BIGINT_WIDTH_WORDS; -const BIGINT_WIDTH_WORDS: usize = 8; +use super::syscall_uint256_mulmod; /// Sets result to be (x op y) % modulus. Currently only multiplication is supported. If modulus is /// zero, the modulus applied is 2^256. diff --git a/zkvm/entrypoint/src/syscalls/blake3_compress.rs b/zkvm/entrypoint/src/syscalls/blake3_compress.rs deleted file mode 100644 index 05a7427fe2..0000000000 --- a/zkvm/entrypoint/src/syscalls/blake3_compress.rs +++ /dev/null @@ -1,22 +0,0 @@ -#[cfg(target_os = "zkvm")] -use core::arch::asm; - -/// Blake3 compress operation. -/// -/// The result is written over the input state. -#[allow(unused_variables)] -#[no_mangle] -pub extern "C" fn syscall_blake3_compress_inner(state: *mut u32, message: *mut u32) { - #[cfg(target_os = "zkvm")] - unsafe { - asm!( - "ecall", - in("t0") crate::syscalls::BLAKE3_COMPRESS_INNER, - in("a0") state, - in("a1") message - ); - } - - #[cfg(not(target_os = "zkvm"))] - unreachable!() -} diff --git a/zkvm/entrypoint/src/syscalls/mod.rs b/zkvm/entrypoint/src/syscalls/mod.rs index 760cabad2c..50be5d7b04 100644 --- a/zkvm/entrypoint/src/syscalls/mod.rs +++ b/zkvm/entrypoint/src/syscalls/mod.rs @@ -1,4 +1,4 @@ -mod blake3_compress; +mod bigint; mod bls12381; mod bn254; mod ed25519; @@ -70,9 +70,6 @@ pub const SECP256K1_DOUBLE: u32 = 0x00_00_01_0B; /// Executes `K256_DECOMPRESS`. 
pub const SECP256K1_DECOMPRESS: u32 = 0x00_00_01_0C; -/// Executes `BLAKE3_COMPRESS_INNER`. -pub const BLAKE3_COMPRESS_INNER: u32 = 0x00_38_01_0D; - /// Executes `BN254_ADD`. pub const BN254_ADD: u32 = 0x00_01_01_0E; diff --git a/zkvm/entrypoint/src/syscalls/sys.rs b/zkvm/entrypoint/src/syscalls/sys.rs index 2a62372a62..520d58ef90 100644 --- a/zkvm/entrypoint/src/syscalls/sys.rs +++ b/zkvm/entrypoint/src/syscalls/sys.rs @@ -1,5 +1,40 @@ +use std::sync::Mutex; + +use lazy_static::lazy_static; +use rand::{rngs::StdRng, Rng, SeedableRng}; + use crate::syscalls::{syscall_halt, syscall_write}; +/// The random number generator seed for the zkVM. +/// +/// In the future, we can pass in this seed from the host or have the verifier generate it. +const PRNG_SEED: u64 = 0x123456789abcdef0; + +lazy_static! { + /// A lazy static holding the global random number generator. + static ref RNG: Mutex<StdRng> = Mutex::new(StdRng::seed_from_u64(PRNG_SEED)); +} + +/// A `Once` used to print a warning the first time the `sys_rand` system call is used. +static SYS_RAND_WARNING: std::sync::Once = std::sync::Once::new(); + +/// Generates random bytes. +/// +/// # Safety +/// +/// Make sure that `recv_buf` points to a buffer of at least `words` bytes. +#[no_mangle] +pub unsafe extern "C" fn sys_rand(recv_buf: *mut u8, words: usize) { + SYS_RAND_WARNING.call_once(|| { + println!("WARNING: Using insecure random number generator."); + }); + let mut rng = RNG.lock().unwrap(); + for i in 0..words { + let element = recv_buf.add(i); + *element = rng.gen(); + } +} + #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn sys_panic(msg_ptr: *const u8, len: usize) -> ! { diff --git a/zkvm/precompiles/src/io.rs b/zkvm/precompiles/src/io.rs index 3fccf95102..cc1be96ed8 100644 --- a/zkvm/precompiles/src/io.rs +++ b/zkvm/precompiles/src/io.rs @@ -8,6 +8,9 @@ use std::io::Write; const FD_HINT: u32 = 4; pub const FD_PUBLIC_VALUES: u32 = 3; +// Runtime hook file descriptors. Make sure these match the FDs in the HookRegistry. +// The default hooks can be found in `core/src/runtime/hooks.rs`. +pub const FD_ECRECOVER_HOOK: u32 = 5; pub struct SyscallWriter { fd: u32, @@ -80,3 +83,8 @@ pub fn hint_slice(buf: &[u8]) { let mut my_reader = SyscallWriter { fd: FD_HINT }; my_reader.write_all(buf).unwrap(); } + +/// Writes the data in `buf` to the file descriptor `fd` using `Write::write_all`. +pub fn write(fd: u32, buf: &[u8]) { + SyscallWriter { fd }.write_all(buf).unwrap(); +} diff --git a/zkvm/precompiles/src/lib.rs b/zkvm/precompiles/src/lib.rs index a31134c12e..4733de84d6 100644 --- a/zkvm/precompiles/src/lib.rs +++ b/zkvm/precompiles/src/lib.rs @@ -1,14 +1,21 @@ -pub mod bigint_mulmod; +//! Precompiles for SP1 zkVM. +//! +//! Specifically, this crate contains user-friendly functions that call SP1 syscalls. Syscalls are +//! also declared here for convenience. In order to avoid duplicate symbol errors, the syscall +//! function impls must live in sp1-zkvm, which is only imported into the end-user program crate. +//! In contrast, sp1-precompiles can be imported into any crate in the dependency tree. 
+ pub mod bls12381; pub mod bn254; pub mod io; pub mod secp256k1; -pub mod uint256_div; pub mod unconstrained; pub mod utils; #[cfg(feature = "verify")] pub mod verify; +pub const BIGINT_WIDTH_WORDS: usize = 8; + extern "C" { pub fn syscall_halt(exit_code: u8) -> !; pub fn syscall_write(fd: u32, write_buf: *const u8, nbytes: usize); @@ -26,7 +33,6 @@ extern "C" { pub fn syscall_bls12381_double(p: *mut u32); pub fn syscall_keccak_permute(state: *mut u64); pub fn syscall_uint256_mulmod(x: *mut u32, y: *const u32); - pub fn syscall_blake3_compress_inner(p: *mut u32, q: *const u32); pub fn syscall_enter_unconstrained() -> bool; pub fn syscall_exit_unconstrained(); pub fn syscall_verify_sp1_proof(vkey: &[u32; 8], pv_digest: &[u8; 32]); @@ -34,4 +40,11 @@ extern "C" { pub fn syscall_hint_read(ptr: *mut u8, len: usize); pub fn sys_alloc_aligned(bytes: usize, align: usize) -> *mut u8; pub fn syscall_bls12381_decompress(point: &mut [u8; 96], is_odd: bool); + pub fn sys_bigint( + result: *mut [u32; BIGINT_WIDTH_WORDS], + op: u32, + x: *const [u32; BIGINT_WIDTH_WORDS], + y: *const [u32; BIGINT_WIDTH_WORDS], + modulus: *const [u32; BIGINT_WIDTH_WORDS], + ); } diff --git a/zkvm/precompiles/src/secp256k1.rs b/zkvm/precompiles/src/secp256k1.rs index 5bbc59a000..46ccf91f0b 100644 --- a/zkvm/precompiles/src/secp256k1.rs +++ b/zkvm/precompiles/src/secp256k1.rs @@ -14,7 +14,7 @@ use k256::elliptic_curve::sec1::ToEncodedPoint; use k256::elliptic_curve::PrimeField; use k256::{PublicKey, Scalar, Secp256k1}; -use crate::io; +use crate::io::{self, FD_ECRECOVER_HOOK}; use crate::unconstrained; const NUM_WORDS: usize = 16; @@ -189,23 +189,15 @@ fn double_and_add_base( /// Either use `decompress_pubkey` and `verify_signature` to verify the results of this function, or /// use `ecrecover`. pub fn unconstrained_ecrecover(sig: &[u8; 65], msg_hash: &[u8; 32]) -> ([u8; 33], Scalar) { + // The `unconstrained!` wrapper is used since none of these computations directly affect + // the output values of the VM. The remainder of the function sets the constraints on the values + // instead. Removing the `unconstrained!` wrapper slightly increases the cycle count. unconstrained! { - let mut recovery_id = sig[64]; - let mut sig = Signature::from_slice(&sig[..64]).unwrap(); - - if let Some(sig_normalized) = sig.normalize_s() { - sig = sig_normalized; - recovery_id ^= 1 - }; - let recid = RecoveryId::from_byte(recovery_id).expect("Recovery ID is valid"); - - let recovered_key = VerifyingKey::recover_from_prehash(&msg_hash[..], &sig, recid).unwrap(); - let bytes = recovered_key.to_sec1_bytes(); - io::hint_slice(&bytes); - - let (_, s) = sig.split_scalars(); - let s_inverse = s.invert(); - io::hint_slice(&s_inverse.to_bytes()); + let mut buf = [0; 65 + 32]; + let (buf_sig, buf_msg_hash) = buf.split_at_mut(sig.len()); + buf_sig.copy_from_slice(sig); + buf_msg_hash.copy_from_slice(msg_hash); + io::write(FD_ECRECOVER_HOOK, &buf); } let recovered_bytes: [u8; 33] = io::read_vec().try_into().unwrap(); diff --git a/zkvm/precompiles/src/uint256_div.rs b/zkvm/precompiles/src/uint256_div.rs deleted file mode 100644 index c12a07b908..0000000000 --- a/zkvm/precompiles/src/uint256_div.rs +++ /dev/null @@ -1,63 +0,0 @@ -#![allow(unused_imports)] -use crate::bigint_mulmod::sys_bigint; -use crate::io; -use crate::syscall_uint256_mulmod; -use crate::unconstrained; -use num::{BigUint, Integer}; - -/// Performs division on 256-bit unsigned integers represented as little endian byte arrays. 
-/// -/// This function divides `x` by `y`, both of which are 256-bit unsigned integers -/// represented as arrays of bytes in little-endian order. It returns the quotient -/// of the division as a 256-bit unsigned integer in the same byte array format. -pub fn uint256_div(x: &mut [u8; 32], y: &[u8; 32]) -> [u8; 32] { - // TODO: this will panic now. - // Assert that the divisor is not zero. - assert!(y != &[0; 32], "division by zero"); - cfg_if::cfg_if! { - if #[cfg(all(target_os = "zkvm", target_vendor = "succinct"))] { - let dividend = BigUint::from_bytes_le(x); - - unconstrained!{ - let divisor = BigUint::from_bytes_le(y); - let (quotient, remainder) = dividend.div_rem(&divisor); - - let mut quotient_bytes = quotient.to_bytes_le(); - quotient_bytes.resize(32, 0u8); - io::hint_slice(&quotient_bytes); - - let mut remainder_bytes = remainder.to_bytes_le(); - remainder_bytes.resize(32, 0u8); - io::hint_slice(&remainder_bytes); - }; - - let quotient_bytes: [u8; 32] = io::read_vec().try_into().unwrap(); - - let remainder_bytes: [u8; 32] = io::read_vec().try_into().unwrap(); - - let remainder = BigUint::from_bytes_le(&remainder_bytes); - - *x = quotient_bytes; - - let mut quotient_times_y = [0u8; 32]; - let zero = [0u32; 8]; - sys_bigint( - quotient_times_y.as_mut_ptr() as *mut [u32; 8], - 0, - quotient_bytes.as_ptr() as *const [u32; 8], - y.as_ptr() as *const [u32; 8], - zero.as_ptr() as *const [u32; 8] - ); - - let quotient_times_divisor = BigUint::from_bytes_le(&quotient_times_y); - assert_eq!(quotient_times_divisor, dividend - remainder); - - *x - } else { - let result_biguint = BigUint::from_bytes_le(x) / BigUint::from_bytes_le(y); - let mut result_biguint_bytes = result_biguint.to_bytes_le(); - result_biguint_bytes.resize(32, 0u8); - result_biguint_bytes.try_into().unwrap_or([0; 32]) - } - } -}
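Note: although the `uint256_div` precompile and its test are removed by this diff, 256-bit modular multiplication remains available to guest programs through the `sys_bigint` declaration added to the sp1-precompiles crate root above. Below is a minimal sketch of calling it, not a verbatim excerpt from the SP1 API: the `sp1_precompiles` import path and the `mulmod_256` helper name are illustrative assumptions, and opcode `0` is taken to select multiplication, matching the call in the removed `uint256_div` helper.

use sp1_precompiles::{sys_bigint, BIGINT_WIDTH_WORDS};

/// Computes (x * y) % modulus over 256-bit values stored as 8 little-endian u32 words.
/// Per the doc comment in the renamed `bigint.rs`, a zero modulus is treated as 2^256.
pub fn mulmod_256(
    x: &[u32; BIGINT_WIDTH_WORDS],
    y: &[u32; BIGINT_WIDTH_WORDS],
    modulus: &[u32; BIGINT_WIDTH_WORDS],
) -> [u32; BIGINT_WIDTH_WORDS] {
    let mut result = [0u32; BIGINT_WIDTH_WORDS];
    // Safety: every pointer passed below refers to an array of exactly
    // BIGINT_WIDTH_WORDS words, as required by the extern declaration.
    unsafe {
        // Opcode 0 = multiplication, the only operation currently supported.
        sys_bigint(&mut result, 0, x, y, modulus);
    }
    result
}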