Skip to content

Commit

Permalink
[js/webgpu] validate transpose perm if specified (#23197)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
xhcao authored Jan 1, 2025
1 parent 0b87bcc commit a3833a5
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ export interface TransposeAttributes extends AttributeWithCacheKey {
readonly perm: number[];
}

const validateInputs = (inputs: readonly TensorView[]): void => {
const validateInputs = (inputs: readonly TensorView[], perm: readonly number[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Transpose requires 1 input.');
}

if (perm.length !== 0 && perm.length !== inputs[0].dims.length) {
throw new Error(`perm size ${perm.length} does not match input rank ${inputs[0].dims.length}`);
}
};

const getAdjustedPerm = (inputRank: number, perm: number[]): number[] =>
perm && perm.length !== inputRank ? [...new Array(inputRank).keys()].reverse() : perm;
perm.length !== 0 ? perm : [...new Array(inputRank).keys()].reverse();

const getOutputShape = (inputShape: readonly number[], perm: number[]): readonly number[] =>
ShapeUtil.sortBasedOnPerm(inputShape, getAdjustedPerm(inputShape.length, perm));
Expand Down Expand Up @@ -191,7 +195,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
};

export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
validateInputs(context.inputs);
validateInputs(context.inputs, attributes.perm);
context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm));
};

Expand Down

0 comments on commit a3833a5

Please sign in to comment.