
Commit

Addressing comments by Hanhan
-- Nit code style fixes.
-- Making the test case smaller.
pashu123 committed Aug 12, 2024
1 parent 55f1611 commit 8bc26df
Showing 2 changed files with 15 additions and 32 deletions.
@@ -81,11 +81,12 @@ tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
   auto zero = rewriter.getIndexAttr(0);
   int64_t numLoops = rootOp.getLoopIteratorTypes().size();
   if (tileSizes.size() > numLoops) {
+    LLVM_DEBUG(llvm::dbgs() << "tile sizes size " << tileSizes.size()
+                            << " exceeds the number of loops " << numLoops
+                            << "\n");
     return failure();
   }
-  while (tileSizes.size() < numLoops) {
-    tileSizes.push_back(zero);
-  }
+  tileSizes.resize(numLoops, zero);
 
   scf::SCFTilingOptions tilingOptions;
   tilingOptions.setTileSizes(tileSizes);
@@ -133,7 +134,7 @@ tileRootAndFuseProducers(IRRewriter &rewriter, TilingInterface rootOp,
   return tiledResults->tiledAndFusedOps.front();
 }
 
-static LogicalResult fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
+static void fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
 
   // Typically, the consumers of the tiled operation are slices of the
   // results of the tiled operation. These are expressed in IR using
@@ -180,7 +181,6 @@ static LogicalResult fuseConsumers(RewriterBase &rewriter, Operation *tiledOp) {
     addCandidateSlices(fusedResult->tiledAndFusedConsumerOperand->getOwner(),
                        candidates);
   }
-  return success();
 }

/// Implementation of tile root and fuse producers and consumers greedily.
@@ -198,7 +198,7 @@ static LogicalResult tileRootAndFuse(IRRewriter &rewriter,
     return failure();
 
   if (!onlyFuseProducerInputOperands)
-    return fuseConsumers(rewriter, tiledOp.value());
+    fuseConsumers(rewriter, tiledOp.value());
 
   return success();
 }
@@ -106,31 +106,14 @@ func.func @dequant_avgpool(%arg0: tensor<1x320x65x65xi8>) -> tensor<1x320x1x1xf3
 }
 
 // CHECK-REDUCTION-LABEL: func.func @dequant_avgpool(
-// CHECK-REDUCTION-SAME: %[[VAL_0:.*]]: tensor<1x320x65x65xi8>) -> tensor<1x320x1x1xf32> {
-// CHECK-REDUCTION: %[[VAL_1:.*]] = arith.constant 5 : index
-// CHECK-REDUCTION: %[[VAL_2:.*]] = arith.constant 1 : index
-// CHECK-REDUCTION: %[[VAL_3:.*]] = arith.constant 65 : index
-// CHECK-REDUCTION: %[[VAL_4:.*]] = arith.constant 1.250000e-01 : f32
-// CHECK-REDUCTION: %[[VAL_5:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-REDUCTION: %[[VAL_6:.*]] = arith.constant 0 : index
-// CHECK-REDUCTION: %[[VAL_7:.*]] = tensor.empty() : tensor<1x320x1x1xf32>
-// CHECK-REDUCTION: %[[VAL_8:.*]] = scf.for %[[VAL_9:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_2]] iter_args(%[[VAL_10:.*]] = %[[VAL_7]]) -> (tensor<1x320x1x1xf32>) {
-// CHECK-REDUCTION: %[[VAL_11:.*]] = scf.for %[[VAL_12:.*]] = %[[VAL_6]] to %[[VAL_3]] step %[[VAL_1]] iter_args(%[[ITER_ARG:.*]] = %[[VAL_10]]) -> (tensor<1x320x1x1xf32>) {
-// CHECK-REDUCTION: %[[VAL_14:.*]] = tensor.extract_slice %[[VAL_0]][0, 0, %[[VAL_9]], %[[VAL_12]]] [1, 320, 1, 5] [1, 1, 1, 1] : tensor<1x320x65x65xi8> to tensor<1x320x1x5xi8>
-// CHECK-REDUCTION: %[[VAL_15:.*]] = tensor.empty() : tensor<1x320x1x5xf32>
-// CHECK-REDUCTION: %[[VAL_16:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[VAL_14]] : tensor<1x320x1x5xi8>) outs(%[[VAL_15]] : tensor<1x320x1x5xf32>) {
-// CHECK-REDUCTION: ^bb0(%[[VAL_17:.*]]: i8, %[[VAL_18:.*]]: f32):
-// CHECK-REDUCTION: %[[VAL_19:.*]] = arith.extsi %[[VAL_17]] : i8 to i32
-// CHECK-REDUCTION: %[[VAL_20:.*]] = arith.sitofp %[[VAL_19]] : i32 to f32
-// CHECK-REDUCTION: %[[VAL_21:.*]] = arith.mulf %[[VAL_20]], %[[VAL_4]] : f32
-// CHECK-REDUCTION: linalg.yield %[[VAL_21]] : f32
-// CHECK-REDUCTION: } -> tensor<1x320x1x5xf32>
-// CHECK-REDUCTION: %[[VAL_22:.*]] = tensor.empty() : tensor<1x5xf32>
-// CHECK-REDUCTION: %[[VAL_23:.*]] = linalg.fill ins(%[[VAL_5]] : f32) outs(%[[VAL_22]] : tensor<1x5xf32>) -> tensor<1x5xf32>
-// CHECK-REDUCTION: %[[RED:.*]] = linalg.pooling_nchw_sum {lowering_config = #config} ins(%[[VAL_16]], %[[VAL_23]] : tensor<1x320x1x5xf32>, tensor<1x5xf32>) outs(%[[ITER_ARG]] : tensor<1x320x1x1xf32>) -> tensor<1x320x1x1xf32>
-// CHECK-REDUCTION: scf.yield %[[RED]] : tensor<1x320x1x1xf32>
+// CHECK-REDUCTION-SAME: {
+// CHECK-REDUCTION: scf.for
+// CHECK-REDUCTION-SAME: {
+// CHECK-REDUCTION: scf.for
+// CHECK-REDUCTION-SAME: {
+// CHECK-REDUCTION: linalg.generic
+// CHECK-REDUCTION: %[[POOL:.+]] = linalg.pooling_nchw_sum
+// CHECK-REDUCTION: scf.yield %[[POOL]]
+// CHECK-REDUCTION: }
 // CHECK-REDUCTION: }
-// CHECK-REDUCTION: scf.yield %[[VAL_11]] : tensor<1x320x1x1xf32>
 // CHECK-REDUCTION: }
-// CHECK-REDUCTION: return %[[VAL_8]] : tensor<1x320x1x1xf32>
-// CHECK-REDUCTION: }
