diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index d26601c0de8d6..cfffcee156c3e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -851,11 +851,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( float alpha, gamma; Value operand; if (binder.tensorOperand(operand) || - binder.f32FloatAttr(alpha, "alpha") || - binder.f32FloatAttr(gamma, "gamma") || + binder.f32FloatAttr(alpha, "alpha", 1.67326) || + binder.f32FloatAttr(gamma, "gamma", 1.0507) || binder.tensorResultType(resultType)) return failure(); + Torch::ValueTensorType inputType = + operand.getType().cast(); + Value vAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), alpha)); @@ -864,12 +867,31 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), gamma)); - Value vInputScale = rewriter.create( + Value cstOne = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, vAlpha, vScale, vInputScale); + Value cstNone = rewriter.create(binder.getLoc()); + Value zeroTensor = rewriter.create( + binder.getLoc(), resultType, operand, cstNone, cstNone, cstNone, + cstNone, cstNone); + Value exp = rewriter.create(binder.getLoc(), + resultType, operand); + Value expMulAlpha = rewriter.create( + binder.getLoc(), resultType, exp, vAlpha); + Value expMulAlphaSubAlpha = rewriter.create( + binder.getLoc(), resultType, expMulAlpha, vAlpha, cstOne); + Value neg = rewriter.create( + binder.getLoc(), resultType, expMulAlphaSubAlpha, vScale); + Value pos = rewriter.create( + binder.getLoc(), resultType, operand, vScale); + Type compareType = inputType.getWithSizesAndDtype( + inputType.getOptionalSizes(), rewriter.getI1Type()); + Value xLessThanZero = rewriter.create( + binder.getLoc(), compareType, operand, zeroTensor); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, xLessThanZero, neg, pos); return success(); }); patterns.onOp("ReduceL1", 1, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e768033ac87f9..2cea874f341f9 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -140,7 +140,7 @@ static Value getScalarIntValue(Value input, Location loc, return nullptr; Type inputDtype = inputTensorType.getOptionalDtype(); - if (!inputDtype || !inputDtype.isInteger(64)) + if (!inputDtype || !(inputDtype.isInteger(64) || inputDtype.isInteger(1))) return nullptr; std::optional inputRank = getTensorRank(input); @@ -148,11 +148,19 @@ static Value getScalarIntValue(Value input, Location loc, return nullptr; if (auto valueTensorLiteralOp = input.getDefiningOp()) { - auto val = valueTensorLiteralOp.getValue() - .cast() - .getSplatValue(); - return rewriter.create( - loc, rewriter.getI64IntegerAttr(val)); + if (inputDtype.isInteger(64)) { + auto val = valueTensorLiteralOp.getValue() + .cast() + .getSplatValue(); + return rewriter.create( + loc, rewriter.getI64IntegerAttr(val)); + } else { + auto val = valueTensorLiteralOp.getValue() + .cast() + .getSplatValue(); + return rewriter.create( + loc, rewriter.getI64IntegerAttr(val)); + } } else if (auto primNumToTensorScalarOp = input.getDefiningOp()) { return primNumToTensorScalarOp.getA(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e426e998ebe04..f34b5a5e3b2cf 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2154,7 +2154,6 @@ "ElementwiseAtenFloorDivideTensorNegativeModule_basic", "ElementwiseLog10IntModule_basic", "ElementwiseLog2IntModule_basic", - "ElementwiseSeluModule_basic", "FlipModuleStaticShape_basic", "FlipNegativeIndexModule_basic", "HardsigmoidModule_basic", @@ -2669,8 +2668,6 @@ "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", - "DropoutTrainModule_basic", - "DropoutTrainStaticShapeModule_basic", "ElementwiseAcosIntModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAtanTensorIntModule_basic",