Commit 3390028

Merge branch 'main' into matthias.remove_make_fx

mgehre-amd authored Jan 13, 2025
2 parents fd2936c + 9a167e2
Showing 6 changed files with 181 additions and 57 deletions.
lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp (4 changes: 2 additions & 2 deletions)
@@ -147,8 +147,8 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
}
// Replace the original index with the index specified
// by the scatter.
- yieldVals[dim] = b.create<arith::TruncIOp>(
-     loc, rewriter.getI32Type(), extractIndexValue);
+ yieldVals[dim] = convertScalarToDtype(
+     rewriter, loc, extractIndexValue, rewriter.getI32Type());
yieldVals.push_back(extractSrcValue);
b.create<linalg::YieldOp>(loc, yieldVals);
})
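The switch from arith::TruncIOp to convertScalarToDtype here is presumably what makes the i32 conversion robust to index element types other than i64: arith.trunci is only valid when the source integer is strictly wider than the target, while convertScalarToDtype picks the appropriate extension, truncation, or no-op for whatever element type it is handed.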
lib/Conversion/TorchToTosa/TorchToTosa.cpp (54 changes: 44 additions & 10 deletions)
@@ -5513,11 +5513,7 @@ class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape)));
}

- rewriter.replaceOpWithNewOp<tensor::CastOp>(
-     op, resultTy,
-     // OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
-     //     op.getType()),
-     result);
+ rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, result);

return success();
}
@@ -6451,11 +6447,7 @@ ConvertAtenOp<Aten__InterpolateSizeListScaleListOp>::matchAndRewrite(
tosa::getConstTensor<int32_t>(rewriter, op,
/*vec=*/{0, 3, 1, 2},
/*shape=*/{static_cast<int32_t>(4)});
- // SmallVector<int64_t> transposedOutputShape(
-     // {transposedResizedOpShape[0], transposedResizedOpShape[3],
-     // transposedResizedOpShape[1], transposedResizedOpShape[2]});
- // auto transposedOutputType = RankedTensorType::get(
-     // makeShapeLLVMCompatible(transposedOutputShape), inputElemTy);

rewriter
.replaceOpWithNewOp<tosa::TransposeOp>(
op, getTypeConverter()->convertType(resultType), resizeOpResult,
@@ -8212,6 +8204,47 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
return success();
}

// Legalization for aten.expm1
template <>
LogicalResult ConvertAtenOp<AtenExpm1Op>::matchAndRewrite(
AtenExpm1Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// expm1 formula:
// y = exp(x) - 1
// Note: this lowering may be less precise than aten.expm1 for inputs near
// zero, since TOSA doesn't have a built-in expm1 op.
auto self = adaptor.getSelf();

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");

auto resultType =
dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
auto resultElemTy = resultType.getElementType();

if (!isa<mlir::FloatType>(resultElemTy))
return rewriter.notifyMatchFailure(
op, "Only floating-point datatype result types are supported");

// If input is not a float type, cast it to the result element type
auto selfElemTy = selfType.getElementType();
if (!isa<mlir::FloatType>(selfElemTy))
self = tosa::promoteType(rewriter, self, resultType);

auto one =
tosa::getConstTensor<float>(rewriter, op, 1.0f, {}, resultElemTy).value();

auto expOp = rewriter.create<tosa::ExpOp>(op->getLoc(), resultType, self);

auto result = rewriter.create<tosa::SubOp>(op->getLoc(), resultType,
expOp.getResult(), one);

rewriter.replaceOp(op, {result.getResult()});

return success();
}

// Legalization for aten.tan
template <>
LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
@@ -8805,6 +8838,7 @@ std::set<StringRef> torch::populateTorchToTosaConversionPatternsAndIllegalOps(
INSERT_ATENOP_PATTERN(AtenLogitOp);
INSERT_ATENOP_PATTERN(AtenLog1pOp);
INSERT_ATENOP_PATTERN(AtenLog10Op);
INSERT_ATENOP_PATTERN(AtenExpm1Op);
INSERT_ATENOP_PATTERN(AtenTanOp);
INSERT_ATENOP_PATTERN(AtenUnfoldOp);
#undef INSERT_ATENOP_PATTERN
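The precision caveat in the new AtenExpm1Op comment is easy to see in practice: for x near zero, exp(x) rounds to 1.0 and the subtraction cancels every significant digit. A quick PyTorch sketch of the effect (illustrative only, not part of this commit):

import torch

x = torch.tensor([1e-10, 1e-5, 1.0], dtype=torch.float32)
print(torch.expm1(x))      # tensor([1.0000e-10, 1.0000e-05, 1.7183e+00])
print(torch.exp(x) - 1.0)  # tensor([0.0000e+00, 1.0014e-05, 1.7183e+00])
# At x = 1e-10, exp(x) rounds to exactly 1.0 in float32, so exp(x) - 1
# returns 0 and loses every significant digit; expm1 stays accurate.

The expm1 e2e tests added to TOSA_PASS_SET below presumably pass because the suite's tolerances absorb this error over the tested value ranges.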
lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp (78 changes: 56 additions & 22 deletions)
@@ -264,42 +264,68 @@ std::optional<Value> getConstTensor<float>(PatternRewriter &rewriter,
return const_op.getResult();
}

- static LogicalResult checkValidityOfCast(Type src, Type dest) {
+ // Valid TOSA casting pairs according to TOSA spec:
+ // https://www.mlplatform.org/tosa/tosa_spec.html#_cast
+ // Note: currently TOSA doesn't support casting to and from I64 and F64
+ [[maybe_unused]] static LogicalResult checkValidityOfCast(Type src, Type dest) {
// clang-format off
if ((src == dest) ||
- // int64 -> *
- (src.isInteger(64) && dest.isInteger(32)) ||
- (src.isInteger(64) && dest.isInteger(8)) ||
- (src.isInteger(64) && dest.isInteger(1)) ||
- (src.isInteger(64) && dest.isF32()) ||
// int32 -> *
- (src.isInteger(32) && dest.isInteger(64)) ||
+ (src.isInteger(32) && dest.isInteger(16)) ||
+ (src.isInteger(32) && dest.isInteger(8)) ||
(src.isInteger(32) && dest.isInteger(1)) ||
(src.isInteger(32) && dest.isF32()) ||
+ (src.isInteger(32) && dest.isF16()) ||
(src.isInteger(32) && dest.isBF16()) ||
// int16 -> *
+ (src.isInteger(16) && dest.isInteger(32)) ||
+ (src.isInteger(16) && dest.isInteger(8)) ||
+ (src.isInteger(16) && dest.isInteger(1)) ||
(src.isInteger(16) && dest.isBF16()) ||
+ (src.isInteger(16) && dest.isF32()) ||
+ (src.isInteger(16) && dest.isF16()) ||
// int8 -> *
+ (src.isInteger(8) && dest.isInteger(32)) ||
+ (src.isInteger(8) && dest.isInteger(16)) ||
(src.isInteger(8) && dest.isInteger(1)) ||
(src.isInteger(8) && dest.isBF16()) ||
+ (src.isInteger(8) && dest.isF32()) ||
+ (src.isInteger(8) && dest.isF16()) ||
// int1 -> *
- (src.isInteger(1) && dest.isInteger(64)) ||
- (src.isInteger(1) && dest.isF32()) ||
- // f64 -> *
- (src.isF64() && dest.isF32()) ||
- (src.isF64() && dest.isBF16()) ||
+ (src.isInteger(1) && dest.isInteger(32)) ||
+ (src.isInteger(1) && dest.isInteger(16)) ||
+ (src.isInteger(1) && dest.isInteger(8)) ||
// f32 -> *
- (src.isF32() && dest.isF64()) ||
+ (src.isF32() && dest.isInteger(32)) ||
+ (src.isF32() && dest.isInteger(16)) ||
+ (src.isF32() && dest.isInteger(8)) ||
(src.isF32() && dest.isBF16()) ||
(src.isF32() && dest.isF16()) ||
- (src.isF32() && dest.isInteger(8)) ||
- (src.isF32() && dest.isInteger(64)) ||
- (src.isF32() && dest.isInteger(1)) ||
+ (src.isF32() && dest.isFloat8E4M3()) ||
+ (src.isF32() && dest.isFloat8E5M2()) ||
+ // f16 -> *
+ (src.isF16() && dest.isInteger(32)) ||
+ (src.isF16() && dest.isInteger(16)) ||
+ (src.isF16() && dest.isInteger(8)) ||
+ (src.isF16() && dest.isBF16()) ||
+ (src.isF16() && dest.isF32()) ||
+ (src.isF16() && dest.isFloat8E4M3()) ||
+ (src.isF16() && dest.isFloat8E5M2()) ||
// bf16 -> *
- (src.isBF16() && dest.isInteger(8)) ||
- (src.isBF16() && dest.isInteger(16)) ||
(src.isBF16() && dest.isInteger(32)) ||
- (src.isBF16() && dest.isF32())) {
+ (src.isBF16() && dest.isInteger(16)) ||
+ (src.isBF16() && dest.isInteger(8)) ||
+ (src.isBF16() && dest.isF32()) ||
+ (src.isBF16() && dest.isFloat8E4M3()) ||
+ (src.isBF16() && dest.isFloat8E5M2()) ||
+ // fp8e4m3 -> *
+ (src.isFloat8E4M3() && dest.isBF16()) ||
+ (src.isFloat8E4M3() && dest.isF32()) ||
+ (src.isFloat8E4M3() && dest.isF16()) ||
+ // fp8e5m2 -> *
+ (src.isFloat8E5M2() && dest.isBF16()) ||
+ (src.isFloat8E5M2() && dest.isF32()) ||
+ (src.isFloat8E5M2() && dest.isF16())) {
return success();
}
// clang-format on
@@ -313,9 +339,17 @@ LogicalResult tosaCastTensorToType(PatternRewriter &rewriter, Operation *op,
Type srcElemTy = dyn_cast<TensorType>(src.getType()).getElementType();
Type destElemTy = dyn_cast<TensorType>(destType).getElementType();

- if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
-   return rewriter.notifyMatchFailure(
-       op, "casting to result dtype is invalid or unsupported");
+ // Temporarily disable checkValidityOfCast: it currently follows the TOSA
+ // spec strictly, which would make many e2e tests fail. Some casting pairs,
+ // though not conformant to the spec, are still permissible in practice;
+ // TOSA validation should flag such illegal constructs in a per-profile
+ // manner instead. The strict check can be re-enabled later behind a
+ // potential `--strict` mode (off by default) that checks casts only when
+ // needed.
+ // if (failed(checkValidityOfCast(srcElemTy, destElemTy)))
+ //   return rewriter.notifyMatchFailure(
+ //       op, "casting to result dtype is invalid or unsupported");

if (destElemTy.isInteger(1)) {
auto srcType = dyn_cast<TensorType>(src.getType());
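To make the motivation concrete, the e2e suite exercises conversions that the spec's CAST table never lists. One hypothetical example, modeled on the TypeConversionI1ToF64Module test referenced in the xfail-set changes below (the commit itself does not spell out which pairs trip the check):

import torch

# bool -> float64: TOSA has no f64 at all, so a strict reading of the CAST
# table would reject this lowering even though the conversion is perfectly
# well defined in PyTorch.
x = torch.tensor([True, False, True])
print(x.to(torch.float64))  # tensor([1., 0., 1.], dtype=torch.float64)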
projects/pt1/e2e_testing/xfail_sets.py (44 changes: 21 additions & 23 deletions)
@@ -822,6 +822,7 @@
"ReplicationPad2dModule_top0",
"ScalarImplicitFloatModule_basic",
# REMOVE WHEN ENABLE_GQA IS ADDED
"ScatterAddDynamicModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",
@@ -1704,12 +1705,31 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleInt2D_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"FullModuleFalsePinMemory_basic",
"FullModuleInt2D_basic",
"MaskedFillScalarFloatValueModule_basic",
"MaskedFillScalarFloatValueStaticModule_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"Threshold3dIntModule_basic",
"TrilIndicesModule_basic",
"TrilIndicesOfssetGreaterThanRowModule_basic",
"TriuIndicesNegativeOffsetModule_basic",
"BmmFloat16Module_basic",
"ElementwiseRreluWithNoiseTrainStaticModule_basic",
"Unfold_Module_Rank_4",
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_basic",
"ElementwiseErfIntModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseIntTensorLtFloatScalarModule_basic",
"ElementwiseSigmoidIntModule_basic",
"ElementwiseSpecialExpm1IntModule_basic",
"ElementwiseSpecialExpm1Module_basic",
"ElementwiseTanIntModule_basic",
"ElementwiseTanModule_basic",
"ElementwiseUnaryIntModule_basic",
@@ -3321,7 +3341,6 @@
"LayerNormFwAndBwModule_basic",
"LayerNormManualFwAndBwModule_basic",
"SelfAttentionFwAndBwModule_basic",
"Threshold3dIntModule_basic",
"ElementwiseCopysignModule_basic",
"ElementwiseSignbitModule_basic",
"Aten_TrilinearModuleVaryingRanks_basic",
@@ -3370,12 +3389,9 @@
"TensorsConcatComplex64FloatModule_basic",
"TimeOutModule_basic",
"TrilIndicesAllZerosModule_basic",
"TrilIndicesModule_basic",
"TrilIndicesNegativeOffsetModule_basic",
"TrilIndicesOfssetGreaterThanRowModule_basic",
"TriuIndicesAllZerosModule_basic",
"TriuIndicesModule_basic",
"TriuIndicesNegativeOffsetModule_basic",
"TypeConversionUint8ToF32Module_basic",
"WeightNormInterfaceModule_basic",
"AdaptiveAvgPool3dDynamicNoBatch_basic",
@@ -3405,8 +3421,6 @@
"AtenComplexViewModule_basic",
"AtenEmbeddingBagStaticModule_basic",
"AtenEmbeddingBagSumExample_basic",
"AtenEyeMModuleInt2D_basic",
"AtenEyeModuleInt2D_basic",
"AtenFloatScalarModule_basic",
"AtenIntBoolOpConstFalseModule_basic",
"AtenIntBoolOpConstTrueModule_basic",
@@ -3441,11 +3455,8 @@
"AvgPool2dIntModule_basic",
"AvgPool2dStaticModule_basic",
"BernoulliFloatModule_basic",
"BernoulliModule_basic",
"BernoulliOnesModule_basic",
"BernoulliPModule_basic",
"BernoulliTensorModule_basic",
"BernoulliZerosModule_basic",
"BincountMinlengthModule_basic",
"BincountModule_basic",
"BincountStaticSizeModule_basic",
@@ -3527,23 +3538,16 @@
"ElementwiseCoshModule_basic",
"ElementwiseDequantizePerChannelModule_basic",
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"ElementwiseMulTensorComplexDiffModule_basic",
"ElementwiseMulTensorComplexModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseSinhIntModule_basic",
"ElementwiseSinhModule_basic",
"ElementwiseSpecialExpm1IntModule_basic",
"ElementwiseSpecialExpm1Module_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"ElementwiseWhereScalarOtherStaticModule_basic",
"EqIntModule_basic",
"FloatImplicitModule_basic",
"FullLikeModuleInt2D_basic",
"FullLikeModuleInt3D_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
@@ -3629,8 +3633,6 @@
"NativeGroupNormBackwardModule_basic",
"NeFloatIntModule_basic",
"NeIntModule_basic",
"NewFullModuleInt2D_basic",
"NewFullModuleInt3D_basic",
"NllLossModuleBackward1DMeanWeight_basic",
"NllLossModuleBackward1DMean_basic",
"NllLossModuleBackward1DSumWeight_basic",
@@ -3643,7 +3645,6 @@
"NormalFunctionalModule_basic",
"NumelModule_basic",
"NumelZeroRankModule_basic",
"OnesLikeModule_falsePinMemory",
"PowIntIntModule_basic",
"PrimMaxIntModule_basic",
"PrimMinIntDynamicModule_basic",
@@ -3739,15 +3740,12 @@
"TorchPrimLoopWhileLikeModule_basic",
"TraceModule_empty",
"TraceUnsignedIntModule_empty",
"TypeConversionI1ToF64Module_basic",
"TypeConversionI1ToI32Module_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"UpSampleNearest2dBackwardScalesNone_basic",
"UpSampleNearest2dBackward_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
"ViewSizeFromOtherTensor_basic",
"VisionTransformerModule_basic",
"ZerosLikeModule_falsePinMemory",
# Unexpected failures due to new PyTorch version update
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
@@ -4510,7 +4508,6 @@
"QuantizedReluUint8_basic",
"QuantizedSingleLayer_basic",
"RandIntDtypeModule_basic",
"RandIntLowDtypeModule_basic",
"RandIntModule_basic",
"RandIntPinMemoryModule_basic",
"RandLikeDtypeModule_basic",
@@ -4594,6 +4591,7 @@
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScatterAddDynamicModule_basic",
"ScatterReduceFloatMaxModule",
"ScatterReduceFloatMaxModuleIncludeSelf",
"ScatterReduceFloatMeanModule",
projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py (25 changes: 25 additions & 0 deletions)
@@ -1045,6 +1045,31 @@ def ScatterAddStaticModule_basic(module, tu: TestUtils):
# ==============================================================================


class ScatterAddDynamicModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.int64, True),
([-1, -1, -1], torch.float32, True),
]
)
def forward(self, input, index, src):
return torch.ops.aten.scatter_add(input, 0, index, src)


@register_test_case(module_factory=lambda: ScatterAddDynamicModule())
def ScatterAddDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 8, 6), tu.randint(2, 4, 3, high=4), tu.rand(5, 8, 6))
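# A sketch of the aten.scatter_add semantics exercised by the test above
# (illustrative, not part of the commit): the operand sizes all differ
# (input 10x8x6, index 2x4x3, src 5x8x6), and only index's extent is
# traversed along each dimension, which is what the dynamic lowering
# must honor.
import torch

inp = torch.rand(10, 8, 6)
index = torch.randint(0, 4, (2, 4, 3))
src = torch.rand(5, 8, 6)
out = torch.ops.aten.scatter_add(inp, 0, index, src)

# Reference loop for dim=0: iterate only over index's shape.
ref = inp.clone()
for i in range(index.size(0)):
    for j in range(index.size(1)):
        for k in range(index.size(2)):
            ref[index[i, j, k], j, k] += src[i, j, k]
assert torch.allclose(out, ref)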


# ==============================================================================


class ScatterReduceFloatModule(torch.nn.Module):
include_self: bool
reduce_type: str
