summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2024-03-08 02:09:29 +0000
committerMatthias Springer <springerm@google.com>2024-03-10 03:15:12 +0000
commite2b6b753bad33cbd03b79d3b9b4c2f0cabfbab8d (patch)
tree2775527e6aca96a78f166ee86953a210fc88b652
parentefcbf902a224e5d0f186a36e99e6f2de56afa827 (diff)
[mlir][Transform] Specify mapping update rules for `apply_conversion_patterns`upstream/users/matthias-springer/conversion_tracking_listener
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h51
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformOps.td11
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp46
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp142
-rw-r--r--mlir/test/Dialect/Transform/ops-invalid.mlir22
-rw-r--r--mlir/test/Dialect/Transform/test-pattern-application.mlir39
6 files changed, 256 insertions, 55 deletions
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
index 32724ff4b98e..5db1a2c28fd4 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h
@@ -1026,7 +1026,7 @@ protected:
/// Return the transform op in which this TrackingListener is used.
TransformOpInterface getTransformOp() const { return transformOp; }
-private:
+protected:
friend class TransformRewriter;
void notifyOperationErased(Operation *op) override;
@@ -1034,6 +1034,7 @@ private:
void notifyOperationReplaced(Operation *op, ValueRange newValues) override;
using Listener::notifyOperationReplaced;
+private:
/// The transform op in which this TrackingListener is used.
TransformOpInterface transformOp;
@@ -1047,23 +1048,48 @@ private:
/// A specialized listener that keeps track of cases in which no replacement
/// payload could be found. The error state of this listener must be checked
/// before the end of its lifetime.
-class ErrorCheckingTrackingListener : public TrackingListener {
+template <typename TrackingListenerTy>
+class ErrorCheckingTrackingListener : public TrackingListenerTy {
public:
- using transform::TrackingListener::TrackingListener;
+ using TrackingListenerTy::TrackingListenerTy;
- ~ErrorCheckingTrackingListener() override;
+ ~ErrorCheckingTrackingListener() override {
+ // The state of the ErrorCheckingTrackingListener must be checked and reset
+ // if there was an error. This is to prevent errors from accidentally being
+ // missed.
+ assert(status.succeeded() && "listener state was not checked");
+ }
/// Check and return the current error state of this listener. Afterwards,
/// resets the error state to "success".
- DiagnosedSilenceableFailure checkAndResetError();
+ DiagnosedSilenceableFailure checkAndResetError() {
+ DiagnosedSilenceableFailure s = std::move(status);
+ status = DiagnosedSilenceableFailure::success();
+ errorCounter = 0;
+ return s;
+ }
/// Return "true" if this tracking listener had a failure.
- bool failed() const;
+ bool failed() const { return !status.succeeded(); }
protected:
- void
- notifyPayloadReplacementNotFound(Operation *op, ValueRange values,
- DiagnosedSilenceableFailure &&diag) override;
+ void notifyPayloadReplacementNotFound(
+ Operation *op, ValueRange values,
+ DiagnosedSilenceableFailure &&diag) override {
+ // Merge potentially existing diags and store the result in the listener.
+ SmallVector<Diagnostic> diags;
+ diag.takeDiagnostics(diags);
+ if (!status.succeeded())
+ status.takeDiagnostics(diags);
+ status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
+
+ // Report more details.
+ status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
+ for (auto &&[index, value] : llvm::enumerate(values))
+ status.attachNote(value.getLoc())
+ << "[" << errorCounter << "] replacement value " << index;
+ ++errorCounter;
+ }
private:
/// The error state of this listener. "Success" indicates that no error
@@ -1082,8 +1108,9 @@ protected:
friend class TransformState;
/// Create a new TransformRewriter.
- explicit TransformRewriter(MLIRContext *ctx,
- ErrorCheckingTrackingListener *listener);
+ explicit TransformRewriter(
+ MLIRContext *ctx,
+ ErrorCheckingTrackingListener<TrackingListener> *listener);
public:
/// Return "true" if the tracking listener had failures.
@@ -1106,7 +1133,7 @@ public:
Operation *replacement);
private:
- ErrorCheckingTrackingListener *const listener;
+ ErrorCheckingTrackingListener<TrackingListener> *const listener;
};
/// This trait is supposed to be attached to Transform dialect operations that
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 1766e4bb875f..686a51bf7f9d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -203,6 +203,16 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
lower ops to different ops (from a different dialect). More details can be
found at the documentation site of `TrackingListener`.
+ The way op handles are updated can be customized with `find_replacements`.
+ If `find_replacements` is set, replacement ops are *not* deduced from the
+ replacement SSA values. The `find_replacements` dictionary attribute
+ specifies the kind of op that should be considered as a replacement for a
+ replaced tracked op. E.g., "arith.mulf => llvm.fmul" specifies that the
+ replacement op for a tracked "arith.mulf" must be an "llvm.fmul" op that was
+ created in the same pattern that replaced the "arith.mulf" op. If there is
+ no such op or if there are multiple such ops, a tracking listener failure
+ is produced.
+
This transform produces a silenceable failure if the dialect conversion was
unsuccessful or the tracking listener failed to find a replacement op.
}];
@@ -212,6 +222,7 @@ def ApplyConversionPatternsOp : TransformDialectOp<"apply_conversion_patterns",
OptionalAttr<StrArrayAttr>:$illegal_ops,
OptionalAttr<StrArrayAttr>:$legal_dialects,
OptionalAttr<StrArrayAttr>:$illegal_dialects,
+ OptionalAttr<DictionaryAttr>:$find_replacements,
UnitAttr:$partial_conversion,
UnitAttr:$preserve_handles);
let results = (outs);
diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
index fe2eea535ffd..92f59c47018f 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp
@@ -935,8 +935,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
}
return true;
};
- transform::ErrorCheckingTrackingListener trackingListener(*this, transform,
- config);
+ transform::ErrorCheckingTrackingListener<transform::TrackingListener>
+ trackingListener(*this, transform, config);
transform::TransformRewriter rewriter(transform->getContext(),
&trackingListener);
@@ -1214,11 +1214,10 @@ DiagnosedSilenceableFailure transform::TrackingListener::findReplacementOp(
Operation *&result, Operation *op, ValueRange newValues) const {
assert(op->getNumResults() == newValues.size() &&
"invalid number of replacement values");
- SmallVector<Value> values(newValues.begin(), newValues.end());
-
DiagnosedSilenceableFailure diag = emitSilenceableFailure(
getTransformOp(), "tracking listener failed to find replacement op "
"during application of this transform op");
+ SmallVector<Value> values(newValues.begin(), newValues.end());
do {
// If the replacement values belong to different ops, drop the mapping.
@@ -1349,49 +1348,12 @@ void transform::TrackingListener::notifyOperationReplaced(
(void)replacePayloadOp(op, replacement);
}
-transform::ErrorCheckingTrackingListener::~ErrorCheckingTrackingListener() {
- // The state of the ErrorCheckingTrackingListener must be checked and reset
- // if there was an error. This is to prevent errors from accidentally being
- // missed.
- assert(status.succeeded() && "listener state was not checked");
-}
-
-DiagnosedSilenceableFailure
-transform::ErrorCheckingTrackingListener::checkAndResetError() {
- DiagnosedSilenceableFailure s = std::move(status);
- status = DiagnosedSilenceableFailure::success();
- errorCounter = 0;
- return s;
-}
-
-bool transform::ErrorCheckingTrackingListener::failed() const {
- return !status.succeeded();
-}
-
-void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound(
- Operation *op, ValueRange values, DiagnosedSilenceableFailure &&diag) {
-
- // Merge potentially existing diags and store the result in the listener.
- SmallVector<Diagnostic> diags;
- diag.takeDiagnostics(diags);
- if (!status.succeeded())
- status.takeDiagnostics(diags);
- status = DiagnosedSilenceableFailure::silenceableFailure(std::move(diags));
-
- // Report more details.
- status.attachNote(op->getLoc()) << "[" << errorCounter << "] replaced op";
- for (auto &&[index, value] : llvm::enumerate(values))
- status.attachNote(value.getLoc())
- << "[" << errorCounter << "] replacement value " << index;
- ++errorCounter;
-}
-
//===----------------------------------------------------------------------===//
// TransformRewriter
//===----------------------------------------------------------------------===//
transform::TransformRewriter::TransformRewriter(
- MLIRContext *ctx, ErrorCheckingTrackingListener *listener)
+ MLIRContext *ctx, ErrorCheckingTrackingListener<TrackingListener> *listener)
: RewriterBase(ctx), listener(listener) {
setListener(listener);
}
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index ca80899ab073..b73fceee7aba 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -493,6 +493,125 @@ void transform::ApplyCanonicalizationPatternsOp::populatePatterns(
// ApplyConversionPatternsOp
//===----------------------------------------------------------------------===//
+namespace {
+/// A specialized tracking listener for dialect conversions. It can be
+/// configured with a "replacement mapping" that specifies how replacement ops
+/// for replaced tracked operations should be determined.
+class ConversionTrackingListener : public transform::TrackingListener {
+public:
+ ConversionTrackingListener(
+ transform::TransformState &state, transform::TransformOpInterface op,
+ transform::TrackingListenerConfig config,
+ const DenseMap<StringRef, StringRef> *replacementMapping)
+ : transform::TrackingListener(state, op, config),
+ replacementMapping(replacementMapping) {}
+
+ /// Instead of deducing the replacement op from the replacement values, the
+ /// replacement op is chosen among all ops that were created during the
+ /// current pattern application. E.g., a mapping of "arith.mulsi_extended ->
+ /// llvm.mul" indicates that tracked arith.mulsi_extended ops should be
+ /// updated to llvm.mul ops, assuming that an llvm.mul op was created in the
+ /// same pattern that replaced the arith.mulsi_extended op. If no such op or
+ /// multiple such ops were created, "nullptr" replacement op is returned.
+ ///
+ /// If no replacement mapping is set, fall back to the original mechanism of
+ /// `TrackingListener`.
+ DiagnosedSilenceableFailure
+ findReplacementOp(Operation *&result, Operation *op,
+ ValueRange newValues) const override;
+
+protected:
+ void notifyOperationErased(Operation *op) override;
+
+ void notifyOperationInserted(Operation *op,
+ OpBuilder::InsertPoint previous) override;
+
+ void notifyPatternBegin(const Pattern &pattern, Operation *op) override;
+
+ void notifyPatternEnd(const Pattern &pattern, LogicalResult status) override;
+
+ /// The root op of the pattern that is currently being applied or "nullptr" if
+ /// no pattern application is running.
+ Operation *rootOp = nullptr;
+
+ /// All ops that have been created during the current pattern application.
+ /// This set is maintained only if "config.replacementMapping" is set.
+ SmallVector<Operation *> createdOps;
+
+ /// A mapping that specifies how replacement ops should be
+ /// determined when a mapped op is replaced. If set to "nullptr", the default
+ /// lookup mechanism (i.e., op deduced from the replacement values) is used.
+ const DenseMap<StringRef, StringRef> *replacementMapping = nullptr;
+};
+} // namespace
+
+void ConversionTrackingListener::notifyOperationErased(Operation *op) {
+ TrackingListener::notifyOperationErased(op);
+
+ // Remove from created ops.
+ auto it = llvm::find(createdOps, op);
+ if (it != createdOps.end())
+ createdOps.erase(it);
+}
+
+void ConversionTrackingListener::notifyOperationInserted(
+ Operation *op, OpBuilder::InsertPoint previous) {
+ if (replacementMapping)
+ createdOps.push_back(op);
+}
+
+void ConversionTrackingListener::notifyPatternBegin(const Pattern &pattern,
+ Operation *op) {
+ assert(!rootOp && "expected that no other pattern is in progress");
+ rootOp = op;
+}
+
+void ConversionTrackingListener::notifyPatternEnd(const Pattern &pattern,
+ LogicalResult status) {
+ rootOp = nullptr;
+ createdOps.clear();
+}
+
+DiagnosedSilenceableFailure
+ConversionTrackingListener::findReplacementOp(Operation *&result, Operation *op,
+ ValueRange newValues) const {
+ if (!replacementMapping)
+ return TrackingListener::findReplacementOp(result, op, newValues);
+
+ DiagnosedSilenceableFailure diag = emitSilenceableFailure(
+ getTransformOp(),
+ "conversion tracking listener failed to find replacement op during "
+ "application of this transform op");
+
+ auto it = replacementMapping->find(op->getName().getStringRef());
+ if (it == replacementMapping->end()) {
+ diag.attachNote(op->getLoc())
+ << "no mapping specified for '" << op->getName().getStringRef() << "'";
+ return diag;
+ }
+ StringRef replacementOpName = it->second;
+ Operation *replacementOp = nullptr;
+ for (Operation *op : createdOps) {
+ if (op->getName().getStringRef() == replacementOpName) {
+ if (replacementOp) {
+ diag.attachNote(op->getLoc()) << "multiple '" << replacementOpName
+ << "' replacement candidates found for '"
+ << op->getName().getStringRef() << "'";
+ return diag;
+ }
+ replacementOp = op;
+ }
+ }
+ if (!replacementOp) {
+ diag.attachNote(op->getLoc())
+ << "no replacement found for '" << op->getName().getStringRef()
+ << "', expected '" << replacementOpName << "'";
+ return diag;
+ }
+ result = replacementOp;
+ return DiagnosedSilenceableFailure::success();
+}
+
DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
transform::TransformRewriter &rewriter,
transform::TransformResults &results, transform::TransformState &state) {
@@ -523,6 +642,15 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
for (Attribute attr : cast<ArrayAttr>(*getIllegalDialects()))
conversionTarget.addIllegalDialect(cast<StringAttr>(attr).getValue());
+ // Extract op replacement rules from attribute.
+ DenseMap<StringRef, StringRef> replacementMapping;
+ if (getFindReplacements()) {
+ DictionaryAttr mappingAttr = cast<DictionaryAttr>(*getFindReplacements());
+ for (auto it : mappingAttr)
+ replacementMapping[it.getName()] =
+ cast<StringAttr>(it.getValue()).getValue();
+ }
+
// Gather all specified patterns.
RewritePatternSet patterns(ctx);
// Need to keep the converters alive until after pattern application because
@@ -569,7 +697,9 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply(
// name.
TrackingListenerConfig trackingConfig;
trackingConfig.requireMatchingReplacementOpName = false;
- ErrorCheckingTrackingListener trackingListener(state, *this, trackingConfig);
+ ErrorCheckingTrackingListener<ConversionTrackingListener> trackingListener(
+ state, *this, trackingConfig,
+ replacementMapping.empty() ? nullptr : &replacementMapping);
ConversionConfig conversionConfig;
if (getPreserveHandles())
conversionConfig.listener = &trackingListener;
@@ -658,6 +788,16 @@ LogicalResult transform::ApplyConversionPatternsOp::verify() {
}
}
}
+ if (getFindReplacements()) {
+ if (!getPreserveHandles())
+ return emitOpError() << "find_replacements requires preserve_handles";
+ auto mapping = cast<DictionaryAttr>(*getFindReplacements());
+ for (auto it : mapping) {
+ if (!isa<StringAttr>(it.getValue()))
+ return emitOpError() << "expected find_replacements to contain only "
+ "StringAttr values";
+ }
+ }
return success();
}
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 73a5f36af929..729645aca2f9 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -771,3 +771,25 @@ module attributes { transform.with_named_sequence } {
transform.yield %arg0 : !transform.any_op
}
}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expected find_replacements to contain only StringAttr values}}
+ transform.apply_conversion_patterns to %arg0 {
+ } {legal_dialects = ["func", "llvm"], preserve_handles,
+ find_replacements = {"arith.muli" = 3}} : !transform.any_op
+ transform.yield
+}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{find_replacements requires preserve_handles}}
+ transform.apply_conversion_patterns to %arg0 {
+ } {legal_dialects = ["func", "llvm"],
+ find_replacements = {"arith.muli" = 3}} : !transform.any_op
+ transform.yield
+}
diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir
index fa8a555af921..7ac2838bb95c 100644
--- a/mlir/test/Dialect/Transform/test-pattern-application.mlir
+++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir
@@ -447,3 +447,42 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// "arith.mulsi_extended" is tracked and replaced with "llvm.mul" (and other
+// ops) during a dialect conversion. Make sure that the handle is updated
+// accordingly.
+
+// CHECK-LABEL: func @dialect_conversion_find_replacements(
+// CHECK-SAME: %[[arg0:.*]]: vector<4xi32>, %[[arg1:.*]]: vector<4xi32>)
+// CHECK: %[[VAL0:.*]] = llvm.sext %[[arg0]] : vector<4xi32> to vector<4xi64>
+// CHECK: %[[VAL1:.*]] = llvm.sext %[[arg1]] : vector<4xi32> to vector<4xi64>
+// CHECK: %[[VAL2:.*]] = llvm.mul %[[VAL0]], %[[VAL1]] {annotated} : vector<4xi64>
+// CHECK: %[[VAL3:.*]] = llvm.trunc %[[VAL2]] : vector<4xi64> to vector<4xi32>
+// CHECK: %[[VAL4:.*]] = llvm.mlir.constant(dense<32> : vector<4xi64>) : vector<4xi64>
+// CHECK: %[[VAL5:.*]] = llvm.lshr %[[VAL2]], %[[VAL4]] : vector<4xi64>
+// CHECK: %[[VAL6:.*]] = llvm.trunc %[[VAL5]] : vector<4xi64> to vector<4xi32>
+// CHECK: return %[[VAL3]], %[[VAL6]] : vector<4xi32>, vector<4xi32>
+func.func @dialect_conversion_find_replacements(%arg0: vector<4xi32>, %arg1: vector<4xi32>) -> (vector<4xi32>, vector<4xi32>) {
+ %c:2 = arith.mulsi_extended %arg0, %arg1 : vector<4xi32>
+ return %c#0, %c#1 : vector<4xi32>, vector<4xi32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op) {
+ %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["arith.mulsi_extended"]} in %0 : (!transform.any_op) -> !transform.any_op
+ // arith.mulsi_extended handles are updated to llvm.mul.
+ transform.apply_conversion_patterns to %0 {
+ transform.apply_conversion_patterns.dialect_to_llvm "arith"
+ } with type_converter {
+ transform.apply_conversion_patterns.memref.memref_to_llvm_type_converter
+ } {legal_dialects = ["func", "llvm"], preserve_handles,
+ find_replacements = {"arith.mulsi_extended" = "llvm.mul"}}
+ : !transform.any_op
+ // Add an attribute to %1, which is now mapped to a new op.
+ transform.annotate %1 "annotated" : !transform.any_op
+ transform.yield
+ }
+}