[Codegen][IGEMM] Add new pass for IGEMM transformation with reshape propagation #18161

Merged · 5 commits · Aug 9, 2024
Changes from 3 commits
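At a high level, the new iree-codegen-convolution-to-igemm pass rewrites a convolution into an implicit GEMM (IGEMM): an iree_linalg_ext.im2col op on the image operand, a collapsed filter, and a contraction, followed by reshape propagation so the introduced collapse/expand ops can fold away. A rough sketch of the rewrite, using the shapes from the lit test added below (value names are illustrative; im2col attributes and the contraction's indexing maps are elided):

    %filter_2d = tensor.collapse_shape %filter [[0, 1, 2], [3]]
        : tensor<3x3x4x16xf32> into tensor<36x16xf32>
    %cols = iree_linalg_ext.im2col ... ins(%input : tensor<1x16x16x4xf32>)
        outs(%col_init : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
    // Contraction with iterator_types = ["parallel", "parallel", "parallel", "reduction"].
    %gemm = linalg.generic ... ins(%cols, %filter_2d : tensor<1x196x36xf32>, tensor<36x16xf32>)
        outs(%fill : tensor<1x196x16xf32>) -> tensor<1x196x16xf32>
    %result = tensor.expand_shape %gemm [[0], [1, 2], [3]] output_shape [1, 14, 14, 16]
        : tensor<1x196x16xf32> into tensor<1x14x14x16xf32>

A second set of patterns then bubbles the reshapes toward the function boundary so they can fold into flow.dispatch tensor loads and stores, as checked by the second test case.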
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
@@ -95,6 +95,7 @@ iree_compiler_cc_library(
"ConvertBf16ArithToF32.cpp",
"ConvertBf16ToUInt16Buffers.cpp",
"ConvertToDestinationPassingStylePass.cpp",
"ConvolutionToIGEMM.cpp",
"DecomposeAffineOpsPass.cpp",
"DecomposeConvolutionToLowerDimOps.cpp",
"DecomposeLinalgGeneric.cpp",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
@@ -86,6 +86,7 @@ iree_cc_library(
"ConvertBf16ArithToF32.cpp"
"ConvertBf16ToUInt16Buffers.cpp"
"ConvertToDestinationPassingStylePass.cpp"
"ConvolutionToIGEMM.cpp"
"DecomposeAffineOpsPass.cpp"
"DecomposeConvolutionToLowerDimOps.cpp"
"DecomposeLinalgGeneric.cpp"
104 changes: 104 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/ConvolutionToIGEMM.cpp
@@ -0,0 +1,104 @@
// Copyright 2022 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/PassDetail.h"
#include "iree/compiler/Codegen/Common/Passes.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Passes.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::iree_compiler {

namespace {

using iree_compiler::IREE::LinalgExt::IREELinalgExtDialect;

class ConvolutionToIGEMMPass
: public ConvolutionToIGEMMBase<ConvolutionToIGEMMPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();

// Rewrite convolutions into an im2col op followed by a GEMM.
{
auto conv2dToIm2colControlFn = [](Operation *conv) {
// Don't transform convolutions that have a preset lowering config.
if (getLoweringConfig(conv)) {
return false;
}
return true;
};
RewritePatternSet patterns(&getContext());
iree_compiler::IREE::LinalgExt::populateConv2DToIm2colOpPatterns(
patterns, conv2dToIm2colControlFn);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
}
}

// The im2col transformation collapses some of the dimensions of the
// convolution operands. Try to push the reshape ops towards the boundaries
// of the function and fold with interface tensor ops.
//
// TODO(Max191): Allow for the im2col op to have multiple M dimensions, and
// generate a multi-M dim contraction instead of collapsing and
// propagating reshapes. It should ultimately become a pass option to
// decide whether to collapse the contraction dimensions into a single
// M/N/K dimension.
{
RewritePatternSet bubbleCollapseShapePatterns(context);
linalg::ControlFusionFn bubbleUpExpansionControlFn =
[](OpOperand *fusedOperand) {
Operation *producer = fusedOperand->get().getDefiningOp();
Operation *consumer = fusedOperand->getOwner();

// Block fusion only if one of the operations has a lowering configuration,
// which means it likely expects tiling specific to its original shape.
if (getLoweringConfig(producer) || getLoweringConfig(consumer)) {
return false;
}
return true;
};
linalg::populateFoldReshapeOpsByCollapsingPatterns(
bubbleCollapseShapePatterns, bubbleUpExpansionControlFn);
// Add patterns for some additional cleanup of reshape ops, on top of the
// canonicalizations that can be done later.
tensor::populateFoldTensorEmptyPatterns(bubbleCollapseShapePatterns);
linalg::FillOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
context);
tensor::CollapseShapeOp::getCanonicalizationPatterns(
bubbleCollapseShapePatterns, context);
tensor::EmptyOp::getCanonicalizationPatterns(bubbleCollapseShapePatterns,
context);
tensor::ExpandShapeOp::getCanonicalizationPatterns(
bubbleCollapseShapePatterns, context);
populateReshapeToInterfaceTensorPatterns(bubbleCollapseShapePatterns);
if (failed(applyPatternsAndFoldGreedily(
getOperation(), std::move(bubbleCollapseShapePatterns)))) {
return signalPassFailure();
}
}
}
};

} // namespace

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvolutionToIGEMMPass() {
return std::make_unique<ConvolutionToIGEMMPass>();
}

} // namespace mlir::iree_compiler
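As exercised by the new lit test, the pass can be run on its own with iree-opt via --pass-pipeline="builtin.module(func.func(iree-codegen-convolution-to-igemm),canonicalize,cse)".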
4 changes: 4 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.h
@@ -86,6 +86,10 @@ std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvertToDestinationPassingStylePass(
bool useWARForCooperativeMatrixCodegen = false);

/// Converts convolution operations to a GEMM with an im2col op on the image.
std::unique_ptr<InterfacePass<FunctionOpInterface>>
createConvolutionToIGEMMPass();

// Decompose affine.apply operations into sub affine.apply that can be
// hoisted in different loops.
std::unique_ptr<Pass> createDecomposeAffineOpsPass();
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
@@ -70,6 +70,13 @@ def ConvertToDestinationPassingStyle :
];
}

def ConvolutionToIGEMM :
InterfacePass<"iree-codegen-convolution-to-igemm", "mlir::FunctionOpInterface"> {
let summary =
"Transforms convolution operations into an implicit GEMM format.";
let constructor = "mlir::iree_compiler::createConvolutionToIGEMMPass()";
}

def DecomposeAffineOps: Pass<"decompose-affine-ops"> {
let summary = "Decompose `affine.apply` operations into sub `affine.apply`";
let description = [{
@@ -27,6 +27,7 @@ iree_lit_test_suite(
"convert_bf16_to_uint16_buffers.mlir",
"convert_bf16_arith_to_f32.mlir",
"convert_to_destination_passing_style.mlir",
"convolution_to_igemm.mlir",
"convolutions.mlir",
"erase_dead_alloc_and_stores.mlir",
"decompose_affine_ops.mlir",
@@ -23,6 +23,7 @@ iree_lit_test_suite(
"convert_bf16_arith_to_f32.mlir"
"convert_bf16_to_uint16_buffers.mlir"
"convert_to_destination_passing_style.mlir"
"convolution_to_igemm.mlir"
"convolutions.mlir"
"decompose_affine_ops.mlir"
"decompose_conv2d.mlir"
@@ -0,0 +1,92 @@
// RUN: iree-opt --split-input-file --pass-pipeline="builtin.module(func.func(iree-codegen-convolution-to-igemm),canonicalize,cse)" %s | FileCheck %s

#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
func.func public @conv_with_consumer(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf16> {
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<1x14x14x16xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
%0 = linalg.conv_2d_nhwc_hwcf
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
%1 = tensor.empty() : tensor<1x14x14x16xf16>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%0 : tensor<1x14x14x16xf32>) outs(%1 : tensor<1x14x14x16xf16>) {
^bb0(%in: f32, %out: f16):
%3 = arith.truncf %in : f32 to f16
linalg.yield %3 : f16
} -> tensor<1x14x14x16xf16>
return %2 : tensor<1x14x14x16xf16>
}
// CHECK: func.func public @conv_with_consumer
// CHECK: %[[IM2COL:.+]] = iree_linalg_ext.im2col
// CHECK-SAME: : tensor<1x196x36xf32>) -> tensor<1x196x36xf32>
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK-SAME: -> tensor<1x196x16xf32>
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK: %[[TRUNCF:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[TRUNCF]] {{\[}}[0], [1, 2], [3]] output_shape [1, 14, 14, 16] : tensor<1x196x16xf16> into tensor<1x14x14x16xf16>
// CHECK: return %[[EXPANDED]] : tensor<1x14x14x16xf16>

// -----

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>,
#hal.descriptor_set.binding<2, storage_buffer>
]>
]>
#config = #iree_gpu.lowering_config<{thread = [2, 16], subgroup = [2, 16]}>
#map = affine_map<(d0, d1) -> (d0, d1)>
module {
func.func @fold_with_interface_tensor() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<1x16x16x4xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<readonly:tensor<3x3x4x16xf32>>
%2 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(2) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>>
%3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [1, 16, 16, 4], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<1x16x16x4xf32>> -> tensor<1x16x16x4xf32>
%4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [3, 3, 4, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<3x3x4x16xf32>> -> tensor<3x3x4x16xf32>
%5 = flow.dispatch.tensor.load %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>> -> tensor<1x14x14x16xf32>
%cst = arith.constant 0.0 : f32
%fill = linalg.fill ins(%cst : f32) outs(%5 : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
%6 = linalg.conv_2d_nhwc_hwcf
{dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64> }
ins(%3, %4: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
flow.dispatch.tensor.store %6, %2, offsets = [0, 0, 0, 0], sizes = [1, 14, 14, 16], strides = [1, 1, 1, 1] : tensor<1x14x14x16xf32> -> !flow.dispatch.tensor<writeonly:tensor<1x14x14x16xf32>>
return
}
}

// CHECK: func.func @fold_with_interface_tensor
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x16x16x4xf32>
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<36x16xf32>
// CHECK-DAG: %[[RES:.+]] = flow.dispatch.tensor.load {{.*}} -> tensor<1x196x16xf32>
// CHECK-DAG: %[[IM2COL:.+]] = iree_linalg_ext.im2col {{.*}} ins(%[[LHS]] : tensor<1x16x16x4xf32>){{.*}}-> tensor<1x196x36xf32>
// CHECK-DAG: %[[FILL:.+]] = linalg.fill {{.*}}outs(%[[RES]] : tensor<1x196x16xf32>)
// CHECK: %[[MATMUL:.+]] = linalg.generic
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%[[IM2COL]], %[[RHS]] : tensor<1x196x36xf32>, tensor<36x16xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x196x16xf32>) {
// CHECK: flow.dispatch.tensor.store %[[MATMUL]]

// -----

#map = affine_map<(d0, d1, d2, d3)->(d0, d1, d2, d3)>
#config = #iree_codegen.lowering_config<tile_sizes = [[0, 1, 4, 32], [0, 1, 2, 4], [0, 0, 0, 0, 1, 1, 4], [0, 1, 0, 0]]>
func.func public @conv_with_lowering_config(%arg0: tensor<1x16x16x4xf32>, %arg1: tensor<3x3x4x16xf32>) -> tensor<1x14x14x16xf32> {
%cst = arith.constant 0.0 : f32
%empty = tensor.empty() : tensor<1x14x14x16xf32>
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
%0 = linalg.conv_2d_nhwc_hwcf {lowering_config = #config,
dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
ins(%arg0, %arg1: tensor<1x16x16x4xf32>, tensor<3x3x4x16xf32>)
outs(%fill: tensor<1x14x14x16xf32>) -> tensor<1x14x14x16xf32>
return %0 : tensor<1x14x14x16xf32>
}
// CHECK: func.func public @conv_with_lowering_config
// CHECK-NOT: iree_linalg_ext.im2col
// CHECK: linalg.conv_2d_nhwc_hwcf
// CHECK-SAME: lowering_config
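The last case checks the control function added in the new pass: a convolution that already carries a lowering_config is left untouched, so any preset tiling decision stays valid.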
@@ -37,6 +37,8 @@ static Value createMul(Location loc, Value x, Value y, OpBuilder &builder) {

namespace {

using ControlFnTy = std::optional<std::function<bool(Operation *)>>;

// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
// and linalg.matmul.
//
@@ -75,8 +77,16 @@ class ConvertConv2DNhwcHwcf final
public:
using OpRewritePattern::OpRewritePattern;

ConvertConv2DNhwcHwcf(MLIRContext *context, ControlFnTy controlFn)
: OpRewritePattern<linalg::Conv2DNhwcHwcfOp>(context),
controlFn(controlFn) {}

LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
PatternRewriter &rewriter) const override {
if (controlFn.has_value() && !controlFn.value()(convOp)) {
return rewriter.notifyMatchFailure(convOp, "controlFn failed.");
}

auto inputType = llvm::cast<ShapedType>(convOp.getInputs()[0].getType());
auto filterType = llvm::cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = llvm::cast<ShapedType>(convOp.getOutputs()[0].getType());
@@ -181,6 +191,9 @@ class ConvertConv2DNhwcHwcf final

return success();
}

private:
ControlFnTy controlFn;
};

// For nchw, because the channels are to the left of the image shape dimensions,
@@ -192,8 +205,16 @@ class ConvertConv2DNchwFchw final
public:
using OpRewritePattern::OpRewritePattern;

ConvertConv2DNchwFchw(MLIRContext *context, ControlFnTy controlFn)
: OpRewritePattern<linalg::Conv2DNchwFchwOp>(context),
controlFn(controlFn) {}

LogicalResult matchAndRewrite(linalg::Conv2DNchwFchwOp convOp,
PatternRewriter &rewriter) const override {
if (controlFn.has_value() && !controlFn.value()(convOp)) {
return rewriter.notifyMatchFailure(convOp, "controlFn failed.");
}

auto inputType = llvm::cast<ShapedType>(convOp.getInputs()[0].getType());
auto filterType = llvm::cast<ShapedType>(convOp.getInputs()[1].getType());
auto outputType = llvm::cast<ShapedType>(convOp.getOutputs()[0].getType());
@@ -296,18 +317,19 @@ class ConvertConv2DNchwFchw final

return success();
}

private:
ControlFnTy controlFn;
};

struct ConvertConv2DToIm2ColOpPass
: ConvertConv2DToIm2ColOpBase<ConvertConv2DToIm2ColOpPass> {
void getDependentDialects(DialectRegistry &registry) const override {
registry
.insert<tensor::TensorDialect, IREE::LinalgExt::IREELinalgExtDialect>();
registry.insert<tensor::TensorDialect, IREELinalgExtDialect>();
}
void runOnOperation() override {
MLIRContext *context = &getContext();
RewritePatternSet patterns(&getContext());
patterns.insert<ConvertConv2DNhwcHwcf, ConvertConv2DNchwFchw>(context);
populateConv2DToIm2colOpPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns)))) {
return signalPassFailure();
@@ -317,6 +339,12 @@ struct ConvertConv2DToIm2ColOpPass

} // namespace

void populateConv2DToIm2colOpPatterns(RewritePatternSet &patterns,
ControlFnTy controlFn) {
patterns.insert<ConvertConv2DNhwcHwcf, ConvertConv2DNchwFchw>(
patterns.getContext(), std::move(controlFn));
}

std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToIm2ColOpPass() {
return std::make_unique<ConvertConv2DToIm2ColOpPass>();
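With this change, the existing ConvertConv2DToIm2ColOp pass routes through the new populateConv2DToIm2colOpPatterns entry point with the default (empty) control function, so its behavior is unchanged, while the ConvolutionToIGEMM pass above supplies a callback that skips convolutions with a preset lowering config.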
@@ -7,6 +7,7 @@
#ifndef IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_
#define IREE_COMPILER_DIALECT_LINALGEXT_TRANSFORMS_PASSES_H_

#include <optional>
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
@@ -49,6 +50,12 @@ createDecomposeWinogradTransformPass();
std::unique_ptr<InterfacePass<mlir::FunctionOpInterface>>
createConvertConv2DToIm2ColOpPass();

// Populates patterns that convert linalg convolution ops into a GEMM with an
// im2col op and reshapes on the inputs.
void populateConv2DToIm2colOpPatterns(
RewritePatternSet &patterns,
std::optional<std::function<bool(Operation *)>> controlFn = std::nullopt);

// Creates a pass to convert linalg convolution ops into a sequence of
// linalg_ext.winograd.* ops and linalg.batch_matmul ops using the winograd
// transformation.