Skip to content

Commit

Permalink
Clean up redundant constant code.
Browse files Browse the repository at this point in the history
  • Loading branch information
AmosLewis committed Dec 7, 2024
1 parent 08339cb commit f4f26db
Showing 1 changed file with 41 additions and 49 deletions.
90 changes: 41 additions & 49 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5712,12 +5712,10 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
Location loc = op.getLoc();
auto si64Type = rewriter.getIntegerType(64, true);
Value si64Dtype = getDtypeIntValueForType(rewriter, loc, si64Type);
// helper for making int constants
std::function<Value(int64_t)> c = [&](int64_t val) {
Value newIntConstant =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(val));
return newIntConstant;
};
auto constantZero =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
auto constantOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
std::function<Value(Value)> makeOneElementList = [&](Value element) {
auto listType = Torch::ListType::get(element.getType());
return rewriter.create<PrimListConstructOp>(loc, listType,
Expand All @@ -5732,10 +5730,12 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
// input_shape_tensor = torch.tensor(original_shape, device=t.device)
auto shapeType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank}, si64Type);
// %2 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[?],i1> ->
// !torch.vtensor<[1],si64>
Value inputShapeTensor =
rewriter.create<Torch::Aten_ShapeAsTensorOp>(loc, shapeType, input);

// t_flat = t.flatten()
// t_flat = t.flatten() # torch.flatten(t, 0, 0)
int64_t flattenedSize = 1;
if (inputType.hasSizes()) {
for (auto size : inputType.getSizes()) {
Expand All @@ -5749,19 +5749,19 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
auto flattenedInputType = rewriter.getType<Torch::ValueTensorType>(
flattendInputShape, inputType.getOptionalDtype());

Value inputDimsStart =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
Value inputDimsEnd = rewriter.create<ConstantIntOp>(
// %1 = torch.aten.flatten.using_ints %arg0, %int0, %int0_0 :
// !torch.vtensor<[?],i1>, !torch.int, !torch.int -> !torch.vtensor<[?],i1>
auto inputDimsEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank - 1));

Value flattenedInput = rewriter.create<AtenFlattenUsingIntsOp>(
loc, flattenedInputType, input, inputDimsStart, inputDimsEnd);
loc, flattenedInputType, input, constantZero /*inputDimsStart*/,
inputDimsEnd /*inputDimsEnd*/);

// nonzero_mask = (t_flat != 0)
auto boolMaskType = inputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), rewriter.getI1Type());
Value boolMask = rewriter.create<AtenNeScalarOp>(loc, boolMaskType,
flattenedInput, c(0));
Value boolMask = rewriter.create<AtenNeScalarOp>(
loc, boolMaskType, flattenedInput, constantZero);

// nonzero_mask = nonzero_mask.int()
Value falseCst = rewriter.create<ConstantBoolOp>(loc, false);
Expand All @@ -5775,25 +5775,22 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
auto cumulativeSumType =
dyn_cast<BaseTensorType>(flattenedInputType.getWithSizesAndDtype(
flattenedInputType.getOptionalSizes(), si64Type));
Value cumulativeSum = rewriter.create<AtenCumsumOp>(loc, cumulativeSumType,
intMask, c(0), noneCst);
Value one =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value subtracted = rewriter.create<AtenSubScalarOp>(
loc, cumulativeSumType, cumulativeSum, one, /*alpha=*/one);
Value cumulativeSum = rewriter.create<AtenCumsumOp>(
loc, cumulativeSumType, intMask, constantZero, noneCst);
Value subtracted =
rewriter.create<AtenSubScalarOp>(loc, cumulativeSumType, cumulativeSum,
constantOne, /*alpha=*/constantOne);

// destination_indices = torch.clamp(destination_indices, min=0)
Value indices = rewriter.create<AtenClampMinOp>(loc, cumulativeSumType,
subtracted, c(0));
subtracted, constantZero);

// iota = torch.arange(len(t_flat), device=t.device) * nonzero_mask
Value constZero = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value end =
rewriter.create<AtenSizeIntOp>(loc, flattenedInput, /*dim=*/constZero);
Value end = rewriter.create<AtenSizeIntOp>(loc, flattenedInput,
/*dim=*/constantZero);
Value rangeTensor = rewriter.create<AtenArangeStartStepOp>(
loc, cumulativeSumType, c(0), end, one, noneCst, noneCst, noneCst,
noneCst);
loc, cumulativeSumType, constantZero, end, constantOne, noneCst,
noneCst, noneCst, noneCst);
Value multiplied = rewriter.create<AtenMulTensorOp>(loc, cumulativeSumType,
rangeTensor, intMask);

Expand All @@ -5810,14 +5807,9 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
// index=destination_indices,
// src=iota, reduce='add')
Value reduceStr = rewriter.create<ConstantStrOp>(loc, "sum");
Value constAxis = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));

Value cstFalse = rewriter.create<ConstantBoolOp>(loc, false);
Value scatteredTensor = rewriter.create<AtenScatterReduceTwoOp>(
loc, cumulativeSumType, zerosTensor, /*axis=*/constAxis,
/*dims=*/indices, /*src=*/multiplied, reduceStr, cstFalse);
loc, cumulativeSumType, zerosTensor, /*axis=*/constantZero,
/*dims=*/indices, /*src=*/multiplied, reduceStr, falseCst);

// result_flat = compacted[:torch.sum(nonzero_mask)]
auto scalarType = ValueTensorType::get(rewriter.getContext(),
Expand All @@ -5831,52 +5823,52 @@ class DecomposeAtenNonzeroOp : public OpRewritePattern<AtenNonzeroOp> {
Value slicedResult =
rewriter.create<AtenSliceTensorOp>(loc, slicedResultType,
/*self=*/scatteredTensor,
/*dim=*/c(0),
/*start=*/c(0),
/*dim=*/constantZero,
/*start=*/constantZero,
/*end=*/numNonzero,
/*step=*/one);
/*step=*/constantOne);

// strides = torch.cumprod(torch.flip(input_shape_tensor, [0]), 0).flip(0)
Value flippedShape = rewriter.create<AtenFlipOp>(
loc, shapeType, inputShapeTensor, makeOneElementList(c(0)));
loc, shapeType, inputShapeTensor, makeOneElementList(constantZero));
Value cumulativeProduct = rewriter.create<AtenCumprodOp>(
loc, shapeType, flippedShape, c(0), noneCst);
loc, shapeType, flippedShape, constantZero, noneCst);
Value flippedCumulativeProduct = rewriter.create<AtenFlipOp>(
loc, shapeType, cumulativeProduct, makeOneElementList(c(0)));
loc, shapeType, cumulativeProduct, makeOneElementList(constantZero));
// strides = torch.cat([strides[1:], torch.tensor([1], device=t.device)])
auto oneTensorType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{1}, si64Type);
Value oneTensor = rewriter.create<AtenScalarTensorOp>(
loc, oneTensorType, c(1), si64Dtype, noneCst, noneCst, noneCst);
loc, oneTensorType, constantOne, si64Dtype, noneCst, noneCst, noneCst);

auto slicedStrideType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{inputRank - 1}, // sizes
si64Type);
Value strideSliceStart = c(1);
Value strideSliceEnd = c(inputRank);
Value strideSliceEnd = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputRank));
Value slicedStrides = rewriter.create<AtenSliceTensorOp>(
loc, slicedStrideType, flippedCumulativeProduct, /*dim*/ c(0),
/*start=*/strideSliceStart, /*end=*/strideSliceEnd, /*step=*/c(1));
loc, slicedStrideType, flippedCumulativeProduct, /*dim*/ constantZero,
/*start=*/constantOne, /*end=*/strideSliceEnd, /*step=*/constantOne);

auto tensorListElementType = Torch::ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize}, si64Type);
Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(tensorListElementType),
SmallVector<Value>{slicedStrides, oneTensor});
Value strides =
rewriter.create<Torch::AtenCatOp>(loc, shapeType, tensorList, c(0));
Value strides = rewriter.create<Torch::AtenCatOp>(loc, shapeType,
tensorList, constantZero);

// multi_indices = (result_flat.unsqueeze(1) // strides.unsqueeze(0)) %
// input_shape_tensor
auto unsqueezedResultType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, 1}, si64Type);
Value unsqueezedResult = rewriter.create<AtenUnsqueezeOp>(
loc, unsqueezedResultType, slicedResult, c(1));
loc, unsqueezedResultType, slicedResult, constantOne);

auto unsqueezedStridesType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{1, inputRank}, si64Type);
Value unsqueezedStrides = rewriter.create<AtenUnsqueezeOp>(
loc, unsqueezedStridesType, strides, c(0));
loc, unsqueezedStridesType, strides, constantZero);

auto dividedBroadcastType = ValueTensorType::get(
rewriter.getContext(), SmallVector<int64_t>{kUnknownSize, inputRank},
Expand Down

0 comments on commit f4f26db

Please sign in to comment.