fix concat codegen (#1311)
Use lmhlo_disc.concat if the operands have fixed shapes.
Yancey1989 authored Aug 27, 2024
1 parent fbe39bc commit 58efe1a
Showing 3 changed files with 33 additions and 9 deletions.
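Note on the guard mentioned in the commit description: the rewrite to lmhlo_disc only fires when every operand of the concatenate op has a statically known shape (the isFixedShape check visible in the first hunk). A minimal sketch of what such a guard can look like, assuming it only needs to verify static memref shapes; the helper below is hypothetical and the in-tree isFixedShape may differ:

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/IR/Operation.h"

    using namespace mlir;

    // Hypothetical stand-in for the isFixedShape guard used by the converter:
    // every memref operand must have a fully static shape for the
    // lmhlo_disc.concatenate rewrite to apply.
    static bool isFixedShapeSketch(Operation* op) {
      for (Value operand : op->getOperands()) {
        auto memref_ty = operand.getType().dyn_cast<MemRefType>();
        // A non-memref or dynamically shaped operand disables the rewrite.
        if (!memref_ty || !memref_ty.hasStaticShape()) return false;
      }
      return true;
    }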
tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc (2 additions, 2 deletions)

@@ -87,11 +87,11 @@ struct LhloConcatenateOpConverter
                                 PatternRewriter& rewriter) const override {
     Operation* op = lhloOp.getOperation();
     if (!isFixedShape(lhloOp)) return failure();

     auto operands = op->getOperands();

     // TODO(yancey): support CPU place
-    if (!placement_utils::isGpuMemRef(operands[0])) return failure();
+    auto deviceAttr = op->getAttrOfType<StringAttr>(kDiscPlaceAssignment);
+    if (!deviceAttr || deviceAttr.getValue() != kGpu) return failure();
     int num_input_operands = op->getNumOperands() - 1;

     SmallVector<Value, 4> ptr_array;
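Note: the change above makes placement explicit. Instead of inferring GPU placement from the first operand's memref (placement_utils::isGpuMemRef), the converter now requires the placement attribute that the DISC placer assigns to the op; the updated test below attaches it as disc.device = "gpu". A small sketch of tagging an op so the new check succeeds, assuming kDiscPlaceAssignment resolves to the "disc.device" attribute name seen in that test and kGpu to "gpu":

    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/Operation.h"

    using namespace mlir;

    // Hedged sketch: tag an op the way a placement pass would, so that
    // op->getAttrOfType<StringAttr>(kDiscPlaceAssignment) in the hunk above
    // finds "gpu". The literal attribute name is taken from the FileCheck test
    // in this commit, not from placement_utils itself.
    static void tagAsGpu(Operation* op) {
      op->setAttr("disc.device", StringAttr::get(op->getContext(), "gpu"));
    }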
tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc (22 additions, 5 deletions)

@@ -1305,14 +1305,14 @@ Value elementalLower<lmhlo::ConcatenateOp>(OpBuilder* b, Location loc,

   b->setInsertionPointToEnd(&if_inbound_ops[i].getElseRegion().front());
   if (i == num_input_operands - 1) {
-    input_index[axis] = b->create<arith::SubIOp>(loc, out_idx, low_bound);
-    auto operand_memref = op.getOperand(i);
+    // we expect this branch to never be executed
+    input_index[axis] = b->create<arith::ConstantIndexOp>(loc, 0);
     auto ret_value =
         check_cache ? createLoadOrUseCachedValue(
-                          loc, b, op.getOperation(), operand_memref,
+                          loc, b, op.getOperation(), op.getOperand(i),
                           input_index, b->saveInsertionPoint(), lower_config)
                     : createMaySpecificLoad(*b, loc, op.getOperation(),
-                                            operand_memref, input_index,
+                                            op.getOperand(i), input_index,
                                             lower_config);
     b->create<scf::YieldOp>(loc, ret_value);
   } else {
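For context on this hunk: elementalLower for lmhlo::ConcatenateOp emits a chain of nested scf.if ops that tests the output index along the concat axis against each operand's range and rebases it into that operand. The else region of the last operand is only reachable for an out-of-range index, so instead of computing out_idx - low_bound there, the lowering now loads from a harmless constant index 0. A host-side C++ analogue of that index selection, as a sketch rather than the actual codegen:

    #include <cstdint>
    #include <vector>

    // Sketch of the per-element operand selection the lowering unrolls into
    // nested scf.if ops; axis_sizes are the operand extents along the concat
    // axis. Returns the operand id and rebases the index into local_idx.
    static int64_t pickOperand(int64_t out_idx,
                               const std::vector<int64_t>& axis_sizes,
                               int64_t& local_idx) {
      int64_t low_bound = 0;
      for (size_t i = 0; i < axis_sizes.size(); ++i) {
        int64_t high_bound = low_bound + axis_sizes[i];
        if (out_idx < high_bound) {  // "then" branch of the i-th scf.if
          local_idx = out_idx - low_bound;
          return static_cast<int64_t>(i);
        }
        low_bound = high_bound;  // fall through to the nested else
      }
      // Unreachable for a valid out_idx; mirrors the dead else branch that the
      // change above now feeds a constant index 0.
      local_idx = 0;
      return static_cast<int64_t>(axis_sizes.size()) - 1;
    }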
@@ -1360,7 +1360,24 @@ Value elementalLower<lmhlo_disc::ConcatenateOp>(OpBuilder* b, Location loc,

   auto int_ptr =
       b->create<memref::LoadOp>(loc, ptr_array, ValueRange{operand_index});
-  Type ptr_type = LLVM::LLVMPointerType::get(FloatType::getF32(ctx));
+  auto elem_ty = out.getType().cast<MemRefType>().getElementType();
+  // pick the LLVM pointer type whose pointee matches the output element type
+  Type ptr_type;
+  if (elem_ty.isBF16()) {
+    ptr_type = LLVM::LLVMPointerType::get(FloatType::getBF16(ctx));
+  } else if (elem_ty.isF16()) {
+    ptr_type = LLVM::LLVMPointerType::get(FloatType::getF16(ctx));
+  } else if (elem_ty.isF32()) {
+    ptr_type = LLVM::LLVMPointerType::get(FloatType::getF32(ctx));
+  } else if (elem_ty.isInteger(32) || elem_ty.isInteger(64) ||
+             elem_ty.isInteger(8)) {
+    ptr_type = LLVM::LLVMPointerType::get(
+        IntegerType::get(ctx, elem_ty.getIntOrFloatBitWidth()));
+  } else {
+    op.emitError("unsupported element type for ConcatenateOp");
+    return Value(nullptr);
+  }
+
   auto llvm_ptr = b->create<LLVM::IntToPtrOp>(loc, ptr_type, int_ptr);

   SmallVector<Value, 4> input_index;
tao_compiler/mlir/disc/transforms/tests/disc-lhlo-rewrite.mlir (9 additions, 2 deletions)

@@ -2,8 +2,15 @@

 module @main attributes {gpu.container_module} {
   func.func @test_concat(%arg0: memref<2x16xf32, #gpu.address_space<global>>, %arg1: memref<2x16xf32, #gpu.address_space<global>>, %out : memref<4x16xf32, #gpu.address_space<global>>) -> memref<4x16xf32, #gpu.address_space<global>> attributes {gpu.kernel} {
-    // CHECK: lmhlo_disc.concatenate
-    "lmhlo.concatenate"(%arg0, %arg1, %out) { dimension = 0 : i64 } : (memref<2x16xf32, #gpu.address_space<global>>, memref<2x16xf32, #gpu.address_space<global>>, memref<4x16xf32, #gpu.address_space<global>>) -> ()
+    // CHECK: memref.alloc() : memref<3xi64>
+    // CHECK: "disc_ral.get_pointer"(%arg0)
+    // CHECK: memref.store %0, %alloc[%c0]
+    // CHECK: "disc_ral.get_pointer"(%arg1)
+    // CHECK: memref.store %1, %alloc[%c1]
+    // CHECK: memref.alloc() : memref<3xi64>
+    // CHECK: "lmhlo_disc.h2d"(%alloc, %alloc_0)
+    // CHECK: "lmhlo_disc.concatenate"(%arg0, %arg1, %alloc_0, %arg2)
+    "lmhlo.concatenate"(%arg0, %arg1, %out) { dimension = 0 : i64, disc.device = "gpu"} : (memref<2x16xf32, #gpu.address_space<global>>, memref<2x16xf32, #gpu.address_space<global>>, memref<4x16xf32, #gpu.address_space<global>>) -> ()
     return %out : memref<4x16xf32, #gpu.address_space<global>>
   }
 }
