Skip to content

Commit

Permalink
Get rid of that weird zero basis hack
Browse files Browse the repository at this point in the history
Now that there's an upstream PR that allows affine.delineraize_index
to clamp, use that instead of the hack I had.
  • Loading branch information
krzysz00 committed Nov 18, 2024
1 parent 714a0b6 commit 47b71e4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -647,17 +647,14 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
getSubgroupSize() / intrinsicLayoutThreadBound);
}

// Add a `0` at the front of the distribution sizes so that
// `affine.delinearize_index` clamp its output (we'll throw away the first
// result).
distributionThreadSizes.insert(distributionThreadSizes.begin(), 0);

// Obtain the offsets from delinearization along the distributionThreadSizes.
// Use a delinearize without outer bound and throw away its initial result
// to get clamping behavior.
SmallVector<OpFoldResult> tileOffsets =
builder
.create<affine::AffineDelinearizeIndexOp>(
loc, getValueOrCreateConstantIndexOp(builder, loc, threadId),
distributionThreadSizes)
distributionThreadSizes, /*hasOuterBound=*/false)
->getResults()
.drop_front();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor<1x1x4x16xf32>, %rhs: t
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]
// CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x4x16x4xf32>)
// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (0, 4, 16)
// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (4, 16)
// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1]
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1]
// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]]
Expand Down Expand Up @@ -426,7 +426,7 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled(%lhs: tensor<1x1x2x4x16x4x
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]
// CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>)
// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (0, 4, 16)
// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (4, 16)
// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]
// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]
Expand Down Expand Up @@ -462,12 +462,12 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups(%lhs: tensor<
// CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]
// CHECK-SAME: %[[ACC:[A-Za-z0-9]+]]
// CHECK: scf.forall (%[[THREAD_ID:.+]]) in (256) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>)
// CHECK-DAG: %[[IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (0, 2, 4, 16)
// CHECK-DAG: %[[IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (2, 4, 16)
// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]]
// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]]
// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1]
// CHECK-DAG: %[[ACC_IDS:.+]]:5 = affine.delinearize_index %[[THREAD_ID]] into (0, 2, 2, 4, 16)
// CHECK-DAG: %[[ACC_IDS:.+]]:5 = affine.delinearize_index %[[THREAD_ID]] into (2, 2, 4, 16)
// CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]]
// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1]
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]]
Expand Down

0 comments on commit 47b71e4

Please sign in to comment.