From 259b7d3830965d07d023666f10cecfc33be9ec12 Mon Sep 17 00:00:00 2001 From: BruceDai Date: Tue, 14 Jan 2025 16:33:38 +0800 Subject: [PATCH] Fix emulation error of gru by 'backward' and 'both' direction options --- index.bs | 114 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 68 insertions(+), 46 deletions(-) diff --git a/index.bs b/index.bs index 0e955244..f6c975f9 100644 --- a/index.bs +++ b/index.bs @@ -4225,20 +4225,26 @@ partial dictionary MLOpSupportLimits { builder, input, weight, recurrentWeight, steps, hiddenSize, options) { const batchSize = input.shape[1]; const inputSize = input.shape[2]; - const numDirections = (options.direction == 'both' ? 2 : 1); + const direction = options.direction || 'forward'; + const numDirections = (direction == 'both' ? 2 : 1); let hiddenState = options.initialHiddenState; if (!hiddenState) { - const desc = {dataType: 'float32', shape: [numDirections, 1, hiddenSize]}; - const totalSize = numDirections * hiddenSize; + const desc = { + dataType: 'float32', + shape: [numDirections, batchSize, hiddenSize] + }; + const totalSize = numDirections * batchSize * hiddenSize; hiddenState = builder.constant(desc, new Float32Array(totalSize).fill(0)); } - let sequence = null; let currentWeight = []; let currentRecurrentWeight = []; let currentBias = []; let currentRecurrentBias = []; + let forwardSequence = null; + let backwardSequence = null; + let outputHidden = null; for (let dir = 0; dir < numDirections; ++dir) { currentWeight.push(squeeze( @@ -4261,57 +4267,73 @@ partial dictionary MLOpSupportLimits { builder.slice( options.recurrentBias, [dir, 0], [1, 3 * hiddenSize]))) : null); - } - - for (let step = 0; step < steps; ++step) { - let currentHidden = []; - let currentOutput = null; - - for (let dir = 0; dir < numDirections; ++dir) { - currentHidden.push(squeeze( - builder, - builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]))); - } + let currentHidden = squeeze( + builder, + builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize])); - for (let dir = 0; dir < numDirections; ++dir) { - let slice = - (dir == 1 || options.direction == 'backward' ? steps - step - 1 : step); - let currentInput = squeeze( + for (let step = 0; step < steps; ++step) { + const slice = (dir == 1 || direction == 'backward' ? steps - step - 1 : step); + const currentInput = squeeze( builder, builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize])); - let result = builder.reshape( - builder.gruCell( - currentInput, - currentWeight[dir], - currentRecurrentWeight[dir], - currentHidden[dir], - hiddenSize, - { - bias: currentBias[dir], - recurrentBias: currentRecurrentBias[dir], - resetAfter: options.resetAfter, - layout: options.layout, - activations: options.activations - }), - [1, batchSize, hiddenSize]); - - currentOutput = - (currentOutput ? builder.concat([currentOutput, result], 0) : result); + currentHidden = builder.gruCell( + currentInput, + currentWeight[dir], + currentRecurrentWeight[dir], + currentHidden, + hiddenSize, + { + bias: currentBias[dir], + recurrentBias: currentRecurrentBias[dir], + resetAfter: options.resetAfter, + layout: options.layout, + activations: options.activations + }); + + if (options.returnSequence) { + // Expand currentHidden of 2D([batchSize, hiddenSize]) + // to 4D([steps, numDirections, batchSize, hiddenSize]) + const expandedHiddenAs4D = builder.reshape( + currentHidden, [1, 1, batchSize, hiddenSize]); + + if (direction == 'forward' || (dir == 0 && direction == 'both')) { + forwardSequence = forwardSequence ? + builder.concat([forwardSequence, expandedHiddenAs4D], 0) : + expandedHiddenAs4D; + } else if (direction == 'backward' || (dir == 1 && direction == 'both')) { + backwardSequence = backwardSequence ? + builder.concat([expandedHiddenAs4D, backwardSequence], 0) : + expandedHiddenAs4D; + } + } } - hiddenState = currentOutput; + // Expand currentHidden of 2D([batchSize, hiddenSize]) + // to 3D([numDirections, batchSize, hiddenSize]) + const expandedHiddenAs3D = builder.reshape( + currentHidden, [1, batchSize, hiddenSize]); + outputHidden = outputHidden ? + builder.concat([outputHidden, expandedHiddenAs3D], 0) : + expandedHiddenAs3D; + } - if (options.returnSequence) { - currentOutput = builder.reshape( - currentOutput, [1, numDirections, batchSize, hiddenSize]); - sequence = - (sequence ? builder.concat([sequence, currentOutput], 0) : - currentOutput); + if (options.returnSequence) { + let outputSequence = null; + + if (direction == 'forward') { + outputSequence = forwardSequence; + } else if (direction == 'backward') { + outputSequence = backwardSequence; + } else if (direction == 'both') { + // Concat along axis 1 (numDirections dimension) + outputSequence = builder.concat([forwardSequence, backwardSequence], 1); } - } - return (sequence ? [hiddenState, sequence] : [hiddenState]); + return [outputHidden, outputSequence]; + } else { + return [outputHidden]; + } }