From 58efe1a138cf32b11380710b51593642a0db18a4 Mon Sep 17 00:00:00 2001 From: Yan Xu Date: Tue, 27 Aug 2024 14:56:44 +0800 Subject: [PATCH] fix concat codegen (#1311) using lmhlo_disc.concat if operands are fixed shape --- .../disc/transforms/disc_lhlo_rewriter.cc | 4 +-- .../disc/transforms/lhlo_elemental_utils.cc | 27 +++++++++++++++---- .../transforms/tests/disc-lhlo-rewrite.mlir | 11 ++++++-- 3 files changed, 33 insertions(+), 9 deletions(-) diff --git a/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc b/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc index 5a08f138708..acb84e0b513 100755 --- a/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc +++ b/tao_compiler/mlir/disc/transforms/disc_lhlo_rewriter.cc @@ -87,11 +87,11 @@ struct LhloConcatenateOpConverter PatternRewriter& rewriter) const override { Operation* op = lhloOp.getOperation(); if (!isFixedShape(lhloOp)) return failure(); - auto operands = op->getOperands(); // TODO(yancey): support CPU place - if (!placement_utils::isGpuMemRef(operands[0])) return failure(); + auto deviceAttr = op->getAttrOfType(kDiscPlaceAssignment); + if (!deviceAttr || deviceAttr.getValue() != kGpu) return failure(); int num_input_operands = op->getNumOperands() - 1; SmallVector ptr_array; diff --git a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc index fcf33a6e4e5..1c9c16c2a50 100644 --- a/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc +++ b/tao_compiler/mlir/disc/transforms/lhlo_elemental_utils.cc @@ -1305,14 +1305,14 @@ Value elementalLower(OpBuilder* b, Location loc, b->setInsertionPointToEnd(&if_inbound_ops[i].getElseRegion().front()); if (i == num_input_operands - 1) { - input_index[axis] = b->create(loc, out_idx, low_bound); - auto operand_memref = op.getOperand(i); + // we expect this branch never be executed + input_index[axis] = b->create(loc, 0); auto ret_value = check_cache ? createLoadOrUseCachedValue( - loc, b, op.getOperation(), operand_memref, + loc, b, op.getOperation(), op.getOperand(i), input_index, b->saveInsertionPoint(), lower_config) : createMaySpecificLoad(*b, loc, op.getOperation(), - operand_memref, input_index, + op.getOperand(i), input_index, lower_config); b->create(loc, ret_value); } else { @@ -1360,7 +1360,24 @@ Value elementalLower(OpBuilder* b, Location loc, auto int_ptr = b->create(loc, ptr_array, ValueRange{operand_index}); - Type ptr_type = LLVM::LLVMPointerType::get(FloatType::getF32(ctx)); + auto elem_ty = out.getType().cast().getElementType(); + // if elem_ty is bf16 + Type ptr_type; + if (elem_ty.isBF16()) { + ptr_type = LLVM::LLVMPointerType::get(FloatType::getBF16(ctx)); + } else if (elem_ty.isF16()) { + ptr_type = LLVM::LLVMPointerType::get(FloatType::getF16(ctx)); + } else if (elem_ty.isF32()) { + ptr_type = LLVM::LLVMPointerType::get(FloatType::getF32(ctx)); + } else if (elem_ty.isInteger(32) || elem_ty.isInteger(64) || + elem_ty.isInteger(8)) { + ptr_type = LLVM::LLVMPointerType::get( + IntegerType::get(ctx, elem_ty.getIntOrFloatBitWidth())); + } else { + op.emitError("unsupported element type for ConcatenateOp"); + return Value(nullptr); + } + auto llvm_ptr = b->create(loc, ptr_type, int_ptr); SmallVector input_index; diff --git a/tao_compiler/mlir/disc/transforms/tests/disc-lhlo-rewrite.mlir b/tao_compiler/mlir/disc/transforms/tests/disc-lhlo-rewrite.mlir index 1a4b046c113..a718a5becc6 100644 --- a/tao_compiler/mlir/disc/transforms/tests/disc-lhlo-rewrite.mlir +++ b/tao_compiler/mlir/disc/transforms/tests/disc-lhlo-rewrite.mlir @@ -2,8 +2,15 @@ module @main attributes {gpu.container_module} { func.func @test_concat(%arg0: memref<2x16xf32, #gpu.address_space>, %arg1: memref<2x16xf32, #gpu.address_space>, %out : memref<4x16xf32, #gpu.address_space>) -> memref<4x16xf32, #gpu.address_space> attributes {gpu.kernel} { - // CHECK: lmhlo_disc.concatenate - "lmhlo.concatenate"(%arg0, %arg1, %out) { dimension = 0 : i64 } : (memref<2x16xf32, #gpu.address_space>, memref<2x16xf32, #gpu.address_space>, memref<4x16xf32, #gpu.address_space>) -> () + // CHECK: memref.alloc() : memref<3xi64> + // CHECK: "disc_ral.get_pointer"(%arg0) + // CHECK: memref.store %0, %alloc[%c0] + // CHECK: "disc_ral.get_pointer"(%arg1) + // CHECK: memref.store %1, %alloc[%c1] + // CHECK: memref.alloc() : memref<3xi64> + // CHECK: "lmhlo_disc.h2d"(%alloc, %alloc_0) + // CHECK: "lmhlo_disc.concatenate"(%arg0, %arg1, %alloc_0, %arg2) + "lmhlo.concatenate"(%arg0, %arg1, %out) { dimension = 0 : i64, disc.device = "gpu"} : (memref<2x16xf32, #gpu.address_space>, memref<2x16xf32, #gpu.address_space>, memref<4x16xf32, #gpu.address_space>) -> () return %out : memref<4x16xf32, #gpu.address_space> } } \ No newline at end of file