summaryrefslogtreecommitdiffstats
path: root/mlir/lib/Transforms/Mem2Reg.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Mem2Reg.cpp')
-rw-r--r--mlir/lib/Transforms/Mem2Reg.cpp16
1 files changed, 15 insertions, 1 deletions
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 80e3b7901632..abe565ea862f 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -202,6 +202,7 @@ private:
/// Contains the reaching definition at this operation. Reaching definitions
/// are only computed for promotable memory operations with blocking uses.
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
+ DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
DominanceInfo &dominance;
MemorySlotPromotionInfo info;
const Mem2RegStatistics &statistics;
@@ -438,6 +439,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
+ replacedValuesMap[memOp] = stored;
}
}
}
@@ -552,6 +554,10 @@ void MemorySlotPromoter::removeBlockingUses() {
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
llvm::SmallVector<Operation *> toErase;
+ // List of all replaced values in the slot.
+ llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
+ // Ops to visit with the `visitReplacedValues` method.
+ llvm::SmallVector<PromotableOpInterface> toVisit;
for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
@@ -565,7 +571,9 @@ void MemorySlotPromoter::removeBlockingUses() {
slot, info.userToBlockingUses[toPromote], rewriter,
reachingDef) == DeletionKind::Delete)
toErase.push_back(toPromote);
-
+ if (toPromoteMemOp.storesTo(slot))
+ if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
+ replacedValuesList.push_back({toPromoteMemOp, replacedValue});
continue;
}
@@ -574,6 +582,12 @@ void MemorySlotPromoter::removeBlockingUses() {
if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
rewriter) == DeletionKind::Delete)
toErase.push_back(toPromote);
+ if (toPromoteBasic.requiresReplacedValues())
+ toVisit.push_back(toPromoteBasic);
+ }
+ for (PromotableOpInterface op : toVisit) {
+ rewriter.setInsertionPointAfter(op);
+ op.visitReplacedValues(replacedValuesList, rewriter);
}
for (Operation *toEraseOp : toErase)