diff options
Diffstat (limited to 'mlir/lib/Transforms/Mem2Reg.cpp')
-rw-r--r-- | mlir/lib/Transforms/Mem2Reg.cpp | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index 80e3b7901632..abe565ea862f 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -202,6 +202,7 @@ private: /// Contains the reaching definition at this operation. Reaching definitions /// are only computed for promotable memory operations with blocking uses. DenseMap<PromotableMemOpInterface, Value> reachingDefs; + DenseMap<PromotableMemOpInterface, Value> replacedValuesMap; DominanceInfo &dominance; MemorySlotPromotionInfo info; const Mem2RegStatistics &statistics; @@ -438,6 +439,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block, assert(stored && "a memory operation storing to a slot must provide a " "new definition of the slot"); reachingDef = stored; + replacedValuesMap[memOp] = stored; } } } @@ -552,6 +554,10 @@ void MemorySlotPromoter::removeBlockingUses() { dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent()); llvm::SmallVector<Operation *> toErase; + // List of all replaced values in the slot. + llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList; + // Ops to visit with the `visitReplacedValues` method. + llvm::SmallVector<PromotableOpInterface> toVisit; for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) { if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) { Value reachingDef = reachingDefs.lookup(toPromoteMemOp); @@ -565,7 +571,9 @@ void MemorySlotPromoter::removeBlockingUses() { slot, info.userToBlockingUses[toPromote], rewriter, reachingDef) == DeletionKind::Delete) toErase.push_back(toPromote); - + if (toPromoteMemOp.storesTo(slot)) + if (Value replacedValue = replacedValuesMap[toPromoteMemOp]) + replacedValuesList.push_back({toPromoteMemOp, replacedValue}); continue; } @@ -574,6 +582,12 @@ void MemorySlotPromoter::removeBlockingUses() { if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote], rewriter) == DeletionKind::Delete) toErase.push_back(toPromote); + if (toPromoteBasic.requiresReplacedValues()) + toVisit.push_back(toPromoteBasic); + } + for (PromotableOpInterface op : toVisit) { + rewriter.setInsertionPointAfter(op); + op.visitReplacedValues(replacedValuesList, rewriter); } for (Operation *toEraseOp : toErase) |