summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp')
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp36
1 files changed, 14 insertions, 22 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
index fea2f659535b..7b4024b6861a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -101,38 +101,30 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
Block *afterBody = loop.getAfterBody();
scf::YieldOp afterTerm = loop.getYieldOp();
- auto argNumber = inductionVar.getArgNumber();
- auto afterTermIndArg = afterTerm.getResults()[argNumber];
+ unsigned argNumber = inductionVar.getArgNumber();
+ Value afterTermIndArg = afterTerm.getResults()[argNumber];
- auto inductionVarAfter = afterBody->getArgument(argNumber);
-
- Value step;
+ Value inductionVarAfter = afterBody->getArgument(argNumber);
// Find suitable `addi` op inside `after` block, one of the args must be an
// Induction var passed from `before` block and second arg must be defined
// outside of the loop and will be considered step value.
// TODO: Add `subi` support?
- for (auto &use : inductionVarAfter.getUses()) {
- auto owner = dyn_cast<arith::AddIOp>(use.getOwner());
- if (!owner)
- continue;
-
- auto other =
- (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs());
- if (!dom.properlyDominates(other, loop))
- continue;
-
- if (afterTermIndArg != owner.getResult())
- continue;
+ auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
+ if (!addOp)
+ return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
- step = other;
- break;
+ Value step;
+ if (addOp.getLhs() == inductionVarAfter) {
+ step = addOp.getRhs();
+ } else if (addOp.getRhs() == inductionVarAfter) {
+ step = addOp.getLhs();
}
- if (!step)
- return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
+ if (!step || !dom.properlyDominates(step, loop))
+ return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
- auto lb = loop.getInits()[argNumber];
+ Value lb = loop.getInits()[argNumber];
assert(lb.getType().isIntOrIndex());
assert(lb.getType() == ub.getType());