diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp')
-rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp | 36 |
1 files changed, 14 insertions, 22 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp index fea2f659535b..7b4024b6861a 100644 --- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp @@ -101,38 +101,30 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, Block *afterBody = loop.getAfterBody(); scf::YieldOp afterTerm = loop.getYieldOp(); - auto argNumber = inductionVar.getArgNumber(); - auto afterTermIndArg = afterTerm.getResults()[argNumber]; + unsigned argNumber = inductionVar.getArgNumber(); + Value afterTermIndArg = afterTerm.getResults()[argNumber]; - auto inductionVarAfter = afterBody->getArgument(argNumber); - - Value step; + Value inductionVarAfter = afterBody->getArgument(argNumber); // Find suitable `addi` op inside `after` block, one of the args must be an // Induction var passed from `before` block and second arg must be defined // outside of the loop and will be considered step value. // TODO: Add `subi` support? - for (auto &use : inductionVarAfter.getUses()) { - auto owner = dyn_cast<arith::AddIOp>(use.getOwner()); - if (!owner) - continue; - - auto other = - (inductionVarAfter == owner.getLhs() ? owner.getRhs() : owner.getLhs()); - if (!dom.properlyDominates(other, loop)) - continue; - - if (afterTermIndArg != owner.getResult()) - continue; + auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>(); + if (!addOp) + return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op"); - step = other; - break; + Value step; + if (addOp.getLhs() == inductionVarAfter) { + step = addOp.getRhs(); + } else if (addOp.getRhs() == inductionVarAfter) { + step = addOp.getLhs(); } - if (!step) - return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op"); + if (!step || !dom.properlyDominates(step, loop)) + return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form"); - auto lb = loop.getInits()[argNumber]; + Value lb = loop.getInits()[argNumber]; assert(lb.getType().isIntOrIndex()); assert(lb.getType() == ub.getType()); |