summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthias Springer <me@m-sp.org>2024-02-23 11:55:24 +0100
committerGitHub <noreply@github.com>2024-02-23 11:55:24 +0100
commit7bb08ee8260c825eb5af4824bc62f73155b4b592 (patch)
treea8641022448d0859e234bbdb5573f6022011f468
parent9dfb8430509619a4e9d36fd00a11b83a2d5d0c3c (diff)
[mlir][Transforms][NFC] Decouple `ConversionPatternRewriterImpl` from `ConversionPatternRewriter` (#82333)
`ConversionPatternRewriterImpl` no longer maintains a reference to the respective `ConversionPatternRewriter`. An `MLIRContext` is sufficient. This commit simplifies the internal state of `ConversionPatternRewriterImpl`.
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp44
1 files changed, 21 insertions, 23 deletions
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 508ee7416d55..d015bd529012 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -756,10 +756,9 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+ explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : rewriter(rewriter), eraseRewriter(rewriter.getContext()),
- config(config) {}
+ : eraseRewriter(ctx), config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -854,8 +853,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
Type origOutputType,
const TypeConverter *converter);
- Value buildUnresolvedArgumentMaterialization(PatternRewriter &rewriter,
- Location loc, ValueRange inputs,
+ Value buildUnresolvedArgumentMaterialization(Block *block, Location loc,
+ ValueRange inputs,
Type origOutputType,
Type outputType,
const TypeConverter *converter);
@@ -934,8 +933,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- PatternRewriter &rewriter;
-
/// This rewriter must be used for erasing ops/blocks.
SingleEraseRewriter eraseRewriter;
@@ -1037,8 +1034,12 @@ void BlockTypeConversionRewrite::rollback() {
LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
function_ref<Operation *(Value)> findLiveUser) {
+ auto builder = OpBuilder::atBlockBegin(block, /*listener=*/&rewriterImpl);
+
// Process the remapping for each of the original arguments.
for (auto it : llvm::enumerate(origBlock->getArguments())) {
+ OpBuilder::InsertionGuard g(builder);
+
// If the type of this argument changed and the argument is still live, we
// need to materialize a conversion.
BlockArgument origArg = it.value();
@@ -1050,14 +1051,12 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
Value replacementValue = rewriterImpl.mapping.lookupOrDefault(origArg);
bool isDroppedArg = replacementValue == origArg;
- if (isDroppedArg)
- rewriterImpl.rewriter.setInsertionPointToStart(getBlock());
- else
- rewriterImpl.rewriter.setInsertionPointAfterValue(replacementValue);
+ if (!isDroppedArg)
+ builder.setInsertionPointAfterValue(replacementValue);
Value newArg;
if (converter) {
newArg = converter->materializeSourceConversion(
- rewriterImpl.rewriter, origArg.getLoc(), origArg.getType(),
+ builder, origArg.getLoc(), origArg.getType(),
isDroppedArg ? ValueRange() : ValueRange(replacementValue));
assert((!newArg || newArg.getType() == origArg.getType()) &&
"materialization hook did not provide a value of the expected "
@@ -1322,6 +1321,8 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
Block *ConversionPatternRewriterImpl::applySignatureConversion(
Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
+ MLIRContext *ctx = block->getParentOp()->getContext();
+
// If no arguments are being changed or added, there is nothing to do.
unsigned origArgCount = block->getNumArguments();
auto convertedTypes = signatureConversion.getConvertedTypes();
@@ -1338,7 +1339,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// Map all new arguments to the location of the argument they originate from.
SmallVector<Location> newLocs(convertedTypes.size(),
- rewriter.getUnknownLoc());
+ Builder(ctx).getUnknownLoc());
for (unsigned i = 0; i < origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap || inputMap->replacementValue)
@@ -1357,8 +1358,6 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
SmallVector<std::optional<ConvertedArgInfo>, 1> argInfo;
argInfo.resize(origArgCount);
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(newBlock);
for (unsigned i = 0; i != origArgCount; ++i) {
auto inputMap = signatureConversion.getInputMapping(i);
if (!inputMap)
@@ -1401,7 +1400,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
outputType = legalOutputType;
newArg = buildUnresolvedArgumentMaterialization(
- rewriter, origArg.getLoc(), replArgs, origOutputType, outputType,
+ newBlock, origArg.getLoc(), replArgs, origOutputType, outputType,
converter);
}
@@ -1439,12 +1438,11 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
return convertOp.getResult(0);
}
Value ConversionPatternRewriterImpl::buildUnresolvedArgumentMaterialization(
- PatternRewriter &rewriter, Location loc, ValueRange inputs,
- Type origOutputType, Type outputType, const TypeConverter *converter) {
- return buildUnresolvedMaterialization(
- MaterializationKind::Argument, rewriter.getInsertionBlock(),
- rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType,
- converter);
+ Block *block, Location loc, ValueRange inputs, Type origOutputType,
+ Type outputType, const TypeConverter *converter) {
+ return buildUnresolvedMaterialization(MaterializationKind::Argument, block,
+ block->begin(), loc, inputs, outputType,
+ origOutputType, converter);
}
Value ConversionPatternRewriterImpl::buildUnresolvedTargetMaterialization(
Location loc, Value input, Type outputType,
@@ -1556,7 +1554,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
+ impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
setListener(impl.get());
}