Skip to content

Commit

Permalink
Fix emulation error of gru by 'backward' and 'both' direction options
Browse files Browse the repository at this point in the history
  • Loading branch information
BruceDai committed Jan 14, 2025
1 parent df555e1 commit 259b7d3
Showing 1 changed file with 68 additions and 46 deletions.
114 changes: 68 additions & 46 deletions index.bs
Original file line number Diff line number Diff line change
Expand Up @@ -4225,20 +4225,26 @@ partial dictionary MLOpSupportLimits {
builder, input, weight, recurrentWeight, steps, hiddenSize, options) {
const batchSize = input.shape[1];
const inputSize = input.shape[2];
const numDirections = (options.direction == 'both' ? 2 : 1);
const direction = options.direction || 'forward';
const numDirections = (direction == 'both' ? 2 : 1);
let hiddenState = options.initialHiddenState;

if (!hiddenState) {
const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]};
const totalSize = numDirections * hiddenSize;
const desc = {
dataType: 'float32',
shape: [numDirections, batchSize, hiddenSize]
};
const totalSize = numDirections * batchSize * hiddenSize;
hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0));
}

let sequence = null;
let currentWeight = [];
let currentRecurrentWeight = [];
let currentBias = [];
let currentRecurrentBias = [];
let forwardSequence = null;
let backwardSequence = null;
let outputHidden = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentWeight.push(squeeze(
Expand All @@ -4261,57 +4267,73 @@ partial dictionary MLOpSupportLimits {
builder.slice(
options.recurrentBias, [dir, 0], [1, 3 * hiddenSize]))) :
null);
}

for (let step = 0; step < steps; ++step) {
let currentHidden = [];
let currentOutput = null;

for (let dir = 0; dir < numDirections; ++dir) {
currentHidden.push(squeeze(
builder,
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])));
}
let currentHidden = squeeze(
builder,
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]));

for (let dir = 0; dir < numDirections; ++dir) {
let slice =
(dir == 1 || options.direction == 'backward' ? steps - step - 1 : step);
let currentInput = squeeze(
for (let step = 0; step < steps; ++step) {
const slice = (dir == 1 || direction == 'backward' ? steps - step - 1 : step);
const currentInput = squeeze(
builder,
builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize]));

let result = builder.reshape(
builder.gruCell(
currentInput,
currentWeight[dir],
currentRecurrentWeight[dir],
currentHidden[dir],
hiddenSize,
{
bias: currentBias[dir],
recurrentBias: currentRecurrentBias[dir],
resetAfter: options.resetAfter,
layout: options.layout,
activations: options.activations
}),
[1, batchSize, hiddenSize]);

currentOutput =
(currentOutput ? builder.concat([currentOutput, result], 0) : result);
currentHidden = builder.gruCell(
currentInput,
currentWeight[dir],
currentRecurrentWeight[dir],
currentHidden,
hiddenSize,
{
bias: currentBias[dir],
recurrentBias: currentRecurrentBias[dir],
resetAfter: options.resetAfter,
layout: options.layout,
activations: options.activations
});

if (options.returnSequence) {
// Expand currentHidden of 2D([batchSize, hiddenSize])
// to 4D([steps, numDirections, batchSize, hiddenSize])
const expandedHiddenAs4D = builder.reshape(
currentHidden, [1, 1, batchSize, hiddenSize]);

if (direction == 'forward' || (dir == 0 && direction == 'both')) {
forwardSequence = forwardSequence ?
builder.concat([forwardSequence, expandedHiddenAs4D], 0) :
expandedHiddenAs4D;
} else if (direction == 'backward' || (dir == 1 && direction == 'both')) {
backwardSequence = backwardSequence ?
builder.concat([expandedHiddenAs4D, backwardSequence], 0) :
expandedHiddenAs4D;
}
}
}

hiddenState = currentOutput;
// Expand currentHidden of 2D([batchSize, hiddenSize])
// to 3D([numDirections, batchSize, hiddenSize])
const expandedHiddenAs3D = builder.reshape(
currentHidden, [1, batchSize, hiddenSize]);
outputHidden = outputHidden ?
builder.concat([outputHidden, expandedHiddenAs3D], 0) :
expandedHiddenAs3D;
}

if (options.returnSequence) {
currentOutput = builder.reshape(
currentOutput, [1, numDirections, batchSize, hiddenSize]);
sequence =
(sequence ? builder.concat([sequence, currentOutput], 0) :
currentOutput);
if (options.returnSequence) {
let outputSequence = null;

if (direction == 'forward') {
outputSequence = forwardSequence;
} else if (direction == 'backward') {
outputSequence = backwardSequence;
} else if (direction == 'both') {
// Concat along axis 1 (numDirections dimension)
outputSequence = builder.concat([forwardSequence, backwardSequence], 1);
}
}

return (sequence ? [hiddenState, sequence] : [hiddenState]);
return [outputHidden, outputSequence];
} else {
return [outputHidden];
}
}
</pre>
</details>
Expand Down

0 comments on commit 259b7d3

Please sign in to comment.