diff options
author | Matthias Springer <springerm@google.com> | 2024-03-12 02:56:46 +0000 |
---|---|---|
committer | Matthias Springer <springerm@google.com> | 2024-03-12 02:58:00 +0000 |
commit | 71cdb0ef4d9962cc085a1f2182a0bcd1e0c025b5 (patch) | |
tree | d432a3043476022148de3843c51326016fc693e7 | |
parent | 15c5ef4723628eae7dbfc1f1738a69f641dd5cc8 (diff) |
[mlir][Transforms][NFC] Make `rewriterImpl` private in `IRRewrite`upstream/users/matthias-springer/rewriter_impl_private
This commit makes `rewriterImpl` private in `IRRewrite`. This ensures that only the conversion value mapping and the dialect conversion configuration can be accessed from an IR rewrite object.
-rw-r--r-- | mlir/lib/Transforms/Utils/DialectConversion.cpp | 41 |
1 files changed, 23 insertions, 18 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index e4a022b7a028..dbdfaeeeb28d 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -232,6 +232,9 @@ protected: const ConversionConfig &getConfig() const; + ConversionValueMapping &getMapping(); + +private: const Kind kind; ConversionPatternRewriterImpl &rewriterImpl; }; @@ -470,7 +473,8 @@ public: /// live users, using the provided `findLiveUser` to search for a user that /// survives the conversion process. LogicalResult - materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser); + materializeLiveConversions(OpBuilder &builder, + function_ref<Operation *(Value)> findLiveUser); void commit(RewriterBase &rewriter) override; @@ -1035,6 +1039,8 @@ const ConversionConfig &IRRewrite::getConfig() const { return rewriterImpl.config; } +ConversionValueMapping &IRRewrite::getMapping() { return rewriterImpl.mapping; } + void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { // Inform the listener about all IR modifications that have already taken // place: References to the original block have been replaced with the new @@ -1049,8 +1055,7 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { llvm::zip_equal(origBlock->getArguments(), argInfo)) { // Handle the case of a 1->0 value mapping. if (!info) { - if (Value newArg = - rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) + if (Value newArg = getMapping().lookupOrNull(origArg, origArg.getType())) rewriter.replaceAllUsesWith(origArg, newArg); continue; } @@ -1061,8 +1066,8 @@ void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) { // If the argument is still used, replace it with the generated cast. if (!origArg.use_empty()) { - rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault( - castValue, origArg.getType())); + rewriter.replaceAllUsesWith( + origArg, getMapping().lookupOrDefault(castValue, origArg.getType())); } } } @@ -1072,23 +1077,23 @@ void BlockTypeConversionRewrite::rollback() { } LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( - function_ref<Operation *(Value)> findLiveUser) { + OpBuilder &builder, function_ref<Operation *(Value)> findLiveUser) { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToStart(block); + // Process the remapping for each of the original arguments. for (auto it : llvm::enumerate(origBlock->getArguments())) { BlockArgument origArg = it.value(); - // Note: `block` may be detached, so OpBuilder::atBlockBegin cannot be used. - OpBuilder builder(it.value().getContext(), /*listener=*/&rewriterImpl); - builder.setInsertionPointToStart(block); // If the type of this argument changed and the argument is still live, we // need to materialize a conversion. - if (rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType())) + if (getMapping().lookupOrNull(origArg, origArg.getType())) continue; Operation *liveUser = findLiveUser(origArg); if (!liveUser) continue; - Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg); + Value replacementValue = getMapping().lookupOrDefault(origArg); bool isDroppedArg = replacementValue == origArg; if (!isDroppedArg) builder.setInsertionPointAfterValue(replacementValue); @@ -1113,13 +1118,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions( << "see existing live user here: " << *liveUser; return failure(); } - rewriterImpl.mapping.map(origArg, newArg); + getMapping().map(origArg, newArg); } return success(); } void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.mapping.lookupOrNull(value); + Value repl = getMapping().lookupOrNull(value); assert(repl && "expected that value is mapped"); if (isa<BlockArgument>(repl)) { @@ -1138,7 +1143,7 @@ void ReplaceAllUsesRewrite::commit(RewriterBase &rewriter) { }); } -void ReplaceAllUsesRewrite::rollback() { rewriterImpl.mapping.erase(value); } +void ReplaceAllUsesRewrite::rollback() { getMapping().erase(value); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>( @@ -1147,7 +1152,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { // Compute replacement values. SmallVector<Value> replacements = llvm::map_to_vector(op->getResults(), [&](OpResult result) { - return rewriterImpl.mapping.lookupOrNull(result, result.getType()); + return getMapping().lookupOrNull(result, result.getType()); }); // Notify the listener that the operation is about to be replaced. @@ -1179,7 +1184,7 @@ void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { void ReplaceOperationRewrite::rollback() { for (auto result : op->getResults()) - rewriterImpl.mapping.erase(result); + getMapping().erase(result); } void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) { @@ -1198,7 +1203,7 @@ void CreateOperationRewrite::rollback() { void UnresolvedMaterializationRewrite::rollback() { if (getMaterializationKind() == MaterializationKind::Target) { for (Value input : op->getOperands()) - rewriterImpl.mapping.erase(input); + getMapping().erase(input); } op->erase(); } @@ -2721,7 +2726,7 @@ LogicalResult OperationConverter::legalizeConvertedArgumentTypes( if (auto *blockTypeConversionRewrite = dyn_cast<BlockTypeConversionRewrite>(rewrite.get())) if (failed(blockTypeConversionRewrite->materializeLiveConversions( - findLiveUser))) + rewriter, findLiveUser))) return failure(); } return success(); |