diff options
author | Arseniy Obolenskiy <gooddoog@student.su> | 2024-01-05 22:49:21 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-05 17:49:21 +0300 |
commit | 59569eb756265b2a5d9d96f6c6c5ee1a3c371c4f (patch) | |
tree | 40be19fc2dfdc7a770866ad10f3791ed682063cf | |
parent | a0e6b7c0429204ac42095be09bd1d5dcad4a052a (diff) |
[mlir] Fix support for loop normalization with integer indices (#76566)
Choose correct type for updated loop boundaries after scf loop
normalization, do not force chosen type to IndexType
-rw-r--r-- | mlir/lib/Dialect/SCF/Utils/Utils.cpp | 7 | ||||
-rw-r--r-- | mlir/test/Dialect/SCF/transform-ops.mlir | 30 |
2 files changed, 35 insertions, 2 deletions
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index e85825595e3c..a2043c647d49 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -502,9 +502,12 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder, Value newLowerBound = isZeroBased ? lowerBound - : boundsBuilder.create<arith::ConstantIndexOp>(loc, 0); + : boundsBuilder.create<arith::ConstantOp>( + loc, boundsBuilder.getZeroAttr(lowerBound.getType())); Value newStep = - isStepOne ? step : boundsBuilder.create<arith::ConstantIndexOp>(loc, 1); + isStepOne ? step + : boundsBuilder.create<arith::ConstantOp>( + loc, boundsBuilder.getIntegerAttr(step.getType(), 1)); // Insert code computing the value of the original loop induction variable // from the "normalized" one. diff --git a/mlir/test/Dialect/SCF/transform-ops.mlir b/mlir/test/Dialect/SCF/transform-ops.mlir index 93ebf67f8b71..a945758143c6 100644 --- a/mlir/test/Dialect/SCF/transform-ops.mlir +++ b/mlir/test/Dialect/SCF/transform-ops.mlir @@ -270,3 +270,33 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +// CHECK-LABEL: func @coalesce_i32_loops( + +// This test checks for loop coalescing success for non-index loop boundaries and step type +func.func @coalesce_i32_loops() { + %0 = arith.constant 0 : i32 + %1 = arith.constant 128 : i32 + %2 = arith.constant 2 : i32 + %3 = arith.constant 64 : i32 + // CHECK-DAG: %[[C0_I32:.*]] = arith.constant 0 : i32 + // CHECK-DAG: %[[C1_I32:.*]] = arith.constant 1 : i32 + // CHECK: scf.for %[[ARG0:.*]] = %[[C0_I32]] to {{.*}} step %[[C1_I32]] : i32 + scf.for %i = %0 to %1 step %2 : i32 { + scf.for %j = %0 to %3 step %2 : i32 { + arith.addi %i, %j : i32 + } + } {coalesce} + return +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.for"]} attributes {coalesce} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.for"> + %2 = transform.loop.coalesce %1: (!transform.op<"scf.for">) -> (!transform.op<"scf.for">) + transform.yield + } +} |