summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2024-03-12 02:56:46 +0000
committerMatthias Springer <springerm@google.com>2024-03-12 02:58:00 +0000
commit71cdb0ef4d9962cc085a1f2182a0bcd1e0c025b5 (patch)
treed432a3043476022148de3843c51326016fc693e7
parent15c5ef4723628eae7dbfc1f1738a69f641dd5cc8 (diff)
[mlir][Transforms][NFC] Make `rewriterImpl` private in `IRRewrite`upstream/users/matthias-springer/rewriter_impl_private
This commit makes `rewriterImpl` private in `IRRewrite`. This ensures that only the conversion value mapping and the dialect conversion configuration can be accessed from an IR rewrite object.
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp41
1 files changed, 23 insertions, 18 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index e4a022b7a028..dbdfaeeeb28d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -232,6 +232,9 @@ protected:
const ConversionConfig &getConfig() const;
+ ConversionValueMapping &getMapping();
+
+private:
const Kind kind;
ConversionPatternRewriterImpl &rewriterImpl;
};
@@ -470,7 +473,8 @@ public:
/// live users, using the provided `findLiveUser` to search for a user that
/// survives the conversion process.
LogicalResult
- materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
+ materializeLiveConversions(OpBuilder &builder,
+ function_ref<Operation *(Value)> findLiveUser);
void commit(RewriterBase &rewriter) override;
@@ -1035,6 +1039,8 @@ const ConversionConfig &IRRewrite::getConfig() const {
return rewriterImpl.config;
}
+ConversionValueMapping &IRRewrite::getMapping() { return rewriterImpl.mapping; }
+
void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
// Inform the listener about all IR modifications that have already taken
// place: References to the original block have been replaced with the new
@@ -1049,8 +1055,7 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
llvm::zip_equal(origBlock->getArguments(), argInfo)) {
// Handle the case of a 1->0 value mapping.
if (!info) {
- if (Value newArg =
- rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+ if (Value newArg = getMapping().lookupOrNull(origArg, origArg.getType()))
rewriter.replaceAllUsesWith(origArg, newArg);
continue;
}
@@ -1061,8 +1066,8 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
// If the argument is still used, replace it with the generated cast.
if (!origArg.use_empty()) {
- rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
- castValue, origArg.getType()));
+ rewriter.replaceAllUsesWith(
+ origArg, getMapping().lookupOrDefault(castValue, origArg.getType()));
}
}
}
@@ -1072,23 +1077,23 @@ void BlockTypeConversionRewrite::rollback() {
}
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
- function_ref<Operation *(Value)> findLiveUser) {
+ OpBuilder &builder, function_ref<Operation *(Value)> findLiveUser) {
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToStart(block);
+
// Process the remapping for each of the original arguments.
for (auto it : llvm::enumerate(origBlock->getArguments())) {
BlockArgument origArg = it.value();
- // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used.
- OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl);
- builder.setInsertionPointToStart(block);
// If the type of this argument changed and the argument is still live, we
// need to materialize a conversion.
- if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
+ if (getMapping().lookupOrNull(origArg, origArg.getType()))
continue;
Operation *liveUser = findLiveUser(origArg);
if (!liveUser)
continue;
- Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
+ Value replacementValue = getMapping().lookupOrDefault(origArg);
bool isDroppedArg = replacementValue == origArg;
if (!isDroppedArg)
builder.setInsertionPointAfterValue(replacementValue);
@@ -1113,13 +1118,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
<< "see existing live user here: " << *liveUser;
return failure();
}
- rewriterImpl.mapping.map(origArg, newArg);
+ getMapping().map(origArg, newArg);
}
return success();
}
void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.mapping.lookupOrNull(value);
+ Value repl = getMapping().lookupOrNull(value);
assert(repl && "expected that value is mapped");
if (isa<BlockArgument>(repl)) {
@@ -1138,7 +1143,7 @@ void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) {
});
}
-void ReplaceAllUsesRewrite::rollback() { rewriterImpl.mapping.erase(value); }
+void ReplaceAllUsesRewrite::rollback() { getMapping().erase(value); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
@@ -1147,7 +1152,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
// Compute replacement values.
SmallVector<Value> replacements =
llvm::map_to_vector(op->getResults(), [&](OpResult result) {
- return rewriterImpl.mapping.lookupOrNull(result, result.getType());
+ return getMapping().lookupOrNull(result, result.getType());
});
// Notify the listener that the operation is about to be replaced.
@@ -1179,7 +1184,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
void ReplaceOperationRewrite::rollback() {
for (auto result : op->getResults())
- rewriterImpl.mapping.erase(result);
+ getMapping().erase(result);
}
void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
@@ -1198,7 +1203,7 @@ void CreateOperationRewrite::rollback() {
void UnresolvedMaterializationRewrite::rollback() {
if (getMaterializationKind() == MaterializationKind::Target) {
for (Value input : op->getOperands())
- rewriterImpl.mapping.erase(input);
+ getMapping().erase(input);
}
op->erase();
}
@@ -2721,7 +2726,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes(
if (auto *blockTypeConversionRewrite =
dyn_cast<BlockTypeConversionRewrite>(rewrite.get()))
if (failed(blockTypeConversionRewrite->materializeLiveConversions(
- findLiveUser)))
+ rewriter, findLiveUser)))
return failure();
}
return success();