Skip to content

Commit

Permalink
Cache key when executing
Browse files Browse the repository at this point in the history
  • Loading branch information
iamalwaysuncomfortable committed Oct 25, 2023
1 parent 861bb94 commit f627b7c
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 112 deletions.
145 changes: 50 additions & 95 deletions create-aleo-app/template-react-zkml/src/Main.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import {
import { ReactSketchCanvas } from "react-sketch-canvas";
import { Column } from "@ant-design/charts";
import aleoLogo from "./assets/aleo.svg";
import {Account, initThreadPool, PrivateKey, ProgramManager, AleoKeyProvider, AleoKeyProviderParams, } from "@aleohq/sdk";
import { AleoWorker } from "./workers/AleoWorker.js";

import { mlp_program, decision_tree_program, test_imageData } from './variables.js';
Expand All @@ -35,31 +34,9 @@ let used_model_type, proving_start_time, proving_end_time;

const aleoWorker = AleoWorker();

const programManager = new ProgramManager();
const generateAccount = async () => {
console.log("generating account")
const key = await aleoWorker.getPrivateKey();
//setAccount(await key.to_string());
};

generateAccount();

const account = new Account();
programManager.setAccount(account);

const keyProvider = new AleoKeyProvider();
keyProvider.useCache = true;
programManager.setKeyProvider(keyProvider);

const keyPair = await programManager.synthesizeKeys(decision_tree_program, "main", ['{x0: -1i64, x1: 5i64}', '{x0: -6i64, x1: 12i64}', '{x0: -4i64, x1: -4i64}', '{x0: 3i64, x1: -1i64}', '{x0: -2i64}', '{x0: 7i64}', '{x0: -6i64}', '{x0: -1i64}', '{x0: 9i64}', '{x0: 6i64}', '{x0: 6i64}', '{x0: -5i64}', '{x0: -12i64}', '{x0: -7i64}', '{x0: -43i64}', '{x0: -12i64}']);
programManager.keyProvider.cacheKeys("hello_hello.aleo:hello", keyPair);

const keyProviderParams = new AleoKeyProviderParams({cacheKey: "tree_mnist_1.aleo:main"});

const Main = () => {
const [account, setAccount] = useState(null);


async function execute(features) {

// helpful tool: https://codepen.io/jsnelders/pen/qBByqQy
Expand Down Expand Up @@ -87,92 +64,70 @@ const Main = () => {
model = mlp_program;
used_model_type = "mlp";
}

const old_proving_method = false;

proving_start_time = performance.now();

console.log("before execution")

let result;

if(old_proving_method) {
result = await aleoWorker.localProgramExecution(
model,
"main",
input_array,
true
);
}
console.log("before execution ")

if(!old_proving_method) {
let executionResponse = await programManager.executeOffline(
const result = await aleoWorker.localProgramExecution(
model,
"main",
input_array,
false,
undefined,
keyProviderParams,
true
);

console.log("executionResponse", executionResponse)
result = executionResponse.getOutputs();
}
proving_end_time = performance.now();
console.log("proving time in seconds", (proving_end_time - proving_start_time) / 1000);


proving_end_time = performance.now();
console.log("proving time in seconds", (proving_end_time - proving_start_time) / 1000);
//const execution = result.getExecution();

//const execution = result.getExecution();

console.log("result", result);
//console.log("execution", execution);
console.log("result", result);
//console.log("execution", execution);

let output_fixed_point_scaling_factor;
let output_fixed_point_scaling_factor;

if(used_model_type == "tree") {
output_fixed_point_scaling_factor = fixed_point_scaling_factor;
}
else if(used_model_type == "mlp") {
output_fixed_point_scaling_factor = fixed_point_scaling_factor**3;
}
if(used_model_type == "tree") {
output_fixed_point_scaling_factor = fixed_point_scaling_factor;
}
else if(used_model_type == "mlp") {
output_fixed_point_scaling_factor = fixed_point_scaling_factor**3;
}

// empty array
var converted_features = [];
// empty array
var converted_features = [];

// iterate over result. For each entry, remove "i64", convert to a number, and divide by the scaling factor
for (let i = 0; i < result.length; i++) {
var output = result[i].replace("i64", "");
output = Number(output);
output = output / output_fixed_point_scaling_factor;
converted_features.push(output);
}
// iterate over result. For each entry, remove "i64", convert to a number, and divide by the scaling factor
for (let i = 0; i < result.length; i++) {
var output = result[i].replace("i64", "");
output = Number(output);
output = output / output_fixed_point_scaling_factor;
converted_features.push(output);
}

console.log("converted_features", converted_features);
console.log("converted_features", converted_features);

if(used_model_type == "mlp") {
const argmax_index = converted_features.indexOf(Math.max(...converted_features));
console.log("argmax_index", argmax_index);
if(used_model_type == "mlp") {
const argmax_index = converted_features.indexOf(Math.max(...converted_features));
console.log("argmax_index", argmax_index);

// compute softmax of converted_features
var softmax = [];
var sum = 0;
for (let i = 0; i < converted_features.length; i++) {
softmax.push(Math.exp(converted_features[i]));
sum += Math.exp(converted_features[i]);
}
for (let i = 0; i < converted_features.length; i++) {
softmax[i] = softmax[i] / sum;
}
console.log("softmax", softmax);

setChartData(
chartData.map((item, index) => ({
...item,
value: softmax[index] * 100, // multiply by 100 if you want to scale it up
})),
);
// compute softmax of converted_features
var softmax = [];
var sum = 0;
for (let i = 0; i < converted_features.length; i++) {
softmax.push(Math.exp(converted_features[i]));
sum += Math.exp(converted_features[i]);
}
for (let i = 0; i < converted_features.length; i++) {
softmax[i] = softmax[i] / sum;
}
console.log("softmax", softmax);

setChartData(
chartData.map((item, index) => ({
...item,
value: softmax[index] * 100, // multiply by 100 if you want to scale it up
})),
);
}

alert(JSON.stringify(converted_features));

Expand All @@ -197,10 +152,6 @@ const Main = () => {
console.error("Failed to get top left pixel data:", error);
}
};





const canvasRef = useRef(null);
const [progress, setProgress] = useState(0);
Expand All @@ -215,6 +166,10 @@ const Main = () => {
setHasMounted(true);
}, []);

async function synthesizeKeys(model) {
await aleoWorker.synthesizeKeys(model, "main");
}

const menuItems = ["Even/Odd", "Number Range", "Classification"].map(
(label, index) => ({
key: String(index + 1),
Expand Down
4 changes: 3 additions & 1 deletion create-aleo-app/template-react-zkml/src/variables.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 48 additions & 11 deletions create-aleo-app/template-react-zkml/src/workers/worker.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,64 @@ import {
initThreadPool,
AleoKeyProvider,
AleoNetworkClient,
NetworkRecordProvider,
Program,
NetworkRecordProvider, AleoKeyProviderParams,
} from "@aleohq/sdk";
import { expose, proxy } from "comlink";
import {sample_inputs} from "../variables.js";

// Initialize the threadpool
await initThreadPool();

async function localProgramExecution(program, aleoFunction, inputs) {
const programManager = new ProgramManager();
// Initialize a program manager with a keyprovider that will cache our keys
const keyProvider = new AleoKeyProvider();
keyProvider.useCache(true);
const programManager = new ProgramManager({ host: "https://api.explorer.aleo.org/v1", keyProvider: keyProvider});
const account = new Account();
programManager.setAccount(account);

async function synthesizeKeys(program_source, aleoFunction) {
const program = Program.fromString(program_source);
const keys = await programManager.synthesizeKeys(program_source, aleoFunction, sample_inputs, new PrivateKey())
const cacheKey = `${program.id()}/${aleoFunction}`;
programManager.keyProvider.cacheKeys(cacheKey, keys);
console.log(`Synthesized keys for ${cacheKey}`);
}

// Create a temporary account for the execution of the program
const account = new Account();
programManager.setAccount(account);
async function localProgramExecution(program_source, aleoFunction, inputs) {
const program = Program.fromString(program_source);
const keySearchParams = new AleoKeyProviderParams({cacheKey: `${program.id()}/${aleoFunction}`});
let cacheFunctionKeys = false;
if (!programManager.keyProvider.containsKeys(keySearchParams.cacheKey)) {
console.log(`No cached keys for ${keySearchParams.cacheKey}`);
cacheFunctionKeys = true;
}

const executionResponse = await programManager.executeOffline(
program,
program_source,
aleoFunction,
inputs,
false, // set to true to get proof
true, // set to true to get proof
undefined,
keySearchParams,
undefined,
undefined,
undefined,
cacheFunctionKeys
);
//console.log(executionResponse.getExecution()); // toString later
return executionResponse.getOutputs(); // proof: executionResponse.
console.log("executionResponse", executionResponse);

if (cacheFunctionKeys) {
console.log("Caching keys");
const keys = executionResponse.getKeys(program.id(), aleoFunction);
programManager.keyProvider.cacheKeys(keySearchParams.cacheKey, [keys.provingKey(), keys.verifyingKey()]);
console.log(`Cached keys for ${keySearchParams.cacheKey}`);
}

console.log("Getting outputs");
const outputs = executionResponse.getOutputs(); // proof: executionResponse.
console.log("outputs", outputs);
return outputs;
}

async function getPrivateKey() {
Expand Down Expand Up @@ -69,5 +106,5 @@ async function deployProgram(program) {
return tx_id;
}

const workerMethods = { localProgramExecution, getPrivateKey, deployProgram };
const workerMethods = { deployProgram, getPrivateKey, localProgramExecution, synthesizeKeys };
expose(workerMethods);
7 changes: 3 additions & 4 deletions sdk/src/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ class ProgramManager {
* @param {ProvingKey | undefined} provingKey Optional proving key to use for the transaction
* @param {VerifyingKey | undefined} verifyingKey Optional verifying key to use for the transaction
* @param {PrivateKey | undefined} privateKey Optional private key to use for the transaction
* @param {boolean} cache Whether to save the private key from the execution
* @returns {Promise<string | Error>}
*
* @example
Expand Down Expand Up @@ -331,6 +332,7 @@ class ProgramManager {
provingKey?: ProvingKey,
verifyingKey?: VerifyingKey,
privateKey?: PrivateKey,
cache = false,
): Promise<ExecutionResponse> {
// Get the private key from the account if it is not provided in the parameters
let executionPrivateKey = privateKey;
Expand All @@ -352,10 +354,7 @@ class ProgramManager {
}

// Run the program offline and return the result
console.log("Running program offline")
console.log("Proving key: ", provingKey);
console.log("Verifying key: ", verifyingKey);
return WasmProgramManager.executeFunctionOffline(executionPrivateKey, program, function_name, inputs, proveExecution, false, imports, provingKey, verifyingKey);
return WasmProgramManager.executeFunctionOffline(executionPrivateKey, program, function_name, inputs, proveExecution, cache, imports, provingKey, verifyingKey);
}

/**
Expand Down
2 changes: 1 addition & 1 deletion wasm/src/programs/manager/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl ProgramManager {

if prove_execution {
log("Preparing inclusion proofs for execution");
let query = QueryNative::from("https://vm.aleo.org/api");
let query = QueryNative::from("https://api.explorer.aleo.org/v1");
trace.prepare_async(query).await.map_err(|err| err.to_string())?;

log("Proving execution");
Expand Down

0 comments on commit f627b7c

Please sign in to comment.