-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathbuiltin_delegate_worker.js
106 lines (92 loc) · 3.32 KB
/
builtin_delegate_worker.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
'use strict';
/* eslint max-len: ["error", {"code": 120}] */
// built-in webnn delegate
importScripts('./tflite-support/tflite_model_runner_cc_simd.js');
// Result object of the most recent CreateFromBufferAndOptions call.
let modelRunnerResult;
// Currently active model runner, or undefined when no model is loaded.
let modelRunner;

/**
 * Fetches the model file and returns its bytes.
 *
 * @param {string} modelPath - URL of the .tflite model.
 * @returns {Promise<Uint8Array>} The raw model bytes.
 * @throws {Error} If the HTTP response is not successful. Without this check an
 *     error page body would be handed to the model loader.
 */
async function fetchModelBytes(modelPath) {
  const response = await fetch(modelPath);
  if (!response.ok) {
    throw new Error(`Failed to fetch model ${modelPath}: ${response.status} ${response.statusText}`);
  }
  return new Uint8Array(await response.arrayBuffer());
}

/**
 * Handles a 'load' message.
 *
 * Instantiates the TFLite WASM module and creates a model runner from the model
 * at `data.modelPath`, replacing any previously loaded runner. On success it
 * posts the elapsed load time in milliseconds, as a string with two decimals.
 *
 * @param {object} data - Message payload: modelPath, enableWebNNDelegate,
 *     webNNDevicePreference and webNNNumThreads. The last two are parsed as
 *     base-10 integers.
 * @throws {Error} If fetching the model or creating the runner fails.
 */
async function loadModel(data) {
  // Release the previous runner and drop the reference, so that a failed reload
  // cannot leave an already-deleted runner behind for a later 'compute'.
  if (modelRunner) {
    modelRunner.delete();
    modelRunner = undefined;
  }
  const loadStart = performance.now();
  // Instantiate the WASM module and fetch the model concurrently.
  const [module, modelBytes] = await Promise.all([
    tflite_model_runner_ModuleFactory(),
    fetchModelBytes(data.modelPath),
  ]);
  // Copy the model bytes into the WASM heap so the runner can read them.
  // NOTE(review): this allocation is never freed. TFLite may reference the
  // buffer for the runner's lifetime, so freeing it here could be unsafe.
  // Confirm ownership before adding a _free.
  const offset = module._malloc(modelBytes.length);
  module.HEAPU8.set(modelBytes, offset);
  modelRunnerResult = module.TFLiteWebModelRunner.CreateFromBufferAndOptions(
      offset,
      modelBytes.length,
      {
        numThreads: 1,
        enableWebNNDelegate: data.enableWebNNDelegate,
        webNNDevicePreference: parseInt(data.webNNDevicePreference, 10),
        webNNNumThreads: parseInt(data.webNNNumThreads, 10),
      },
  );
  if (!modelRunnerResult.ok()) {
    // The error is carried by the result object. `modelRunner` is unset at this
    // point and must not be used here.
    throw new Error(
        `Failed to create TFLiteWebModelRunner: ${modelRunnerResult.errorMessage()}`);
  }
  modelRunner = modelRunnerResult.value();
  const loadFinishedMs = (performance.now() - loadStart).toFixed(2);
  postMessage(loadFinishedMs);
}

/**
 * Handles a 'compute' message.
 *
 * Copies `data.buffer` into the first input tensor, runs inference once and
 * posts `{outputBuffer}`. `outputBuffer` is a copy of the first output tensor's
 * data, and its ArrayBuffer is transferred to the main thread.
 *
 * @param {object} data - Message payload containing `buffer`. The caller is
 *     assumed to have preprocessed it to the model's input layout (the original
 *     author noted 224 x 224 x 3). TODO confirm.
 * @throws {Error} If no model has been loaded yet.
 */
function runCompute(data) {
  if (!modelRunner) {
    throw new Error('No model loaded: send a \'load\' message before \'compute\'.');
  }
  // The C++ vectors returned by GetInputs/GetOutputs are deleted right after
  // being converted to JS arrays.
  const inputs = callAndDelete(modelRunner.GetInputs(), (results) => convertCppVectorToArray(results));
  const outputs = callAndDelete(modelRunner.GetOutputs(), (results) => convertCppVectorToArray(results));
  const input = inputs[0];
  const output = outputs[0];
  input.data().set(data.buffer);
  const inferStart = performance.now();
  modelRunner.Infer();
  const inferTime = performance.now() - inferStart;
  console.log(`Infer time in worker: ${inferTime.toFixed(2)} ms`);
  // data() presumably views WASM heap memory. Copy it so that only the copy's
  // own ArrayBuffer is transferred to the main thread.
  const outputBuffer = output.data().slice(0);
  postMessage({outputBuffer}, [outputBuffer.buffer]);
}

/**
 * Worker entry point. Dispatches on `message.data.action`:
 * 'load' calls loadModel, 'compute' calls runCompute, and any other action is
 * ignored.
 *
 * @param {MessageEvent} message - Message received from the main thread.
 */
onmessage = async (message) => {
  if (!message || !message.data) {
    return;
  }
  switch (message.data.action) {
    case 'load':
      await loadModel(message.data);
      break;
    case 'compute':
      runCompute(message.data);
      break;
    default:
      break;
  }
};
// Helper functions.
/**
 * Copies the elements of a C++ vector binding (anything exposing `size()` and
 * `get(i)`) into a plain JS array, in index order.
 *
 * @param {?{size: function(): number, get: function(number): *}} vector
 *     Vector to read, or null/undefined.
 * @returns {Array} The elements, or an empty array when `vector` is null or
 *     undefined.
 */
function convertCppVectorToArray(vector) {
  if (vector === null || vector === undefined) {
    return [];
  }
  const items = [];
  let index = 0;
  // size() is re-queried on every iteration, as in a classic for-loop.
  while (index < vector.size()) {
    items.push(vector.get(index));
    index += 1;
  }
  return items;
}
/**
 * Invokes `func(arg)` and then releases `arg` by calling its `delete()` method.
 * The release happens whether `func` returns or throws; null/undefined
 * arguments are not released.
 *
 * @param {?{delete: function()}} arg - Deletable object (e.g. an embind handle).
 * @param {function(*): *} func - Callback that consumes `arg`.
 * @returns {*} Whatever `func` returns.
 */
function callAndDelete(arg, func) {
  let outcome;
  try {
    outcome = func(arg);
  } finally {
    const isDeletable = arg !== null && arg !== undefined;
    if (isDeletable) {
      arg.delete();
    }
  }
  return outcome;
}