summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2024-04-04 03:07:43 +0000
committerMatthias Springer <springerm@google.com>2024-04-04 03:07:45 +0000
commit4c819864c5edbdb8137451c0a0b6a97240a17008 (patch)
tree3617ed9d8ff7712b6c445d3629bdeb9c0d92d3d3
parent9df19ce40281551bd348b262a131085cf98dadf5 (diff)
[mlir][SCF] Add `scf.for` bufferization preprocessing passupstream/users/matthias-springer/scf_bufferization_preprocessing
Add a bufferization preprocessing pass for `scf.for` loops to support loops where a yielded tensor value does not bufferize to the equivalent corresponding iter_arg buffer. This preprocessing works around a limitation of `scf.for` bufferization by inserting additional buffer copies for yielded tensors. This preprocessing can be used to support most cases where One-Shot Bufferize fails to bufferize the IR with the following error message: ``` error: Yield operand #0 is not equivalent to the corresponding iter bbArg ```
-rw-r--r--mlir/include/mlir/Dialect/SCF/Transforms/Passes.h3
-rw-r--r--mlir/include/mlir/Dialect/SCF/Transforms/Passes.td21
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp33
-rw-r--r--mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir23
4 files changed, 80 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
index 90b315e83a8c..6107219ea94a 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h
@@ -23,6 +23,9 @@ namespace mlir {
/// Creates a pass that bufferizes the SCF dialect.
std::unique_ptr<Pass> createSCFBufferizePass();
+/// Creates a pass that preprocesses SCF loops before One-Shot Bufferize.
+std::unique_ptr<Pass> createSCFLoopBufferizationPreprocessingPass();
+
/// Creates a pass that specializes for loop for unrolling and
/// vectorization.
std::unique_ptr<Pass> createForLoopSpecializationPass();
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
index 350611ad8687..94d3e51a1c90 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td
@@ -18,6 +18,27 @@ def SCFBufferize : Pass<"scf-bufferize"> {
"memref::MemRefDialect"];
}
+// Preprocessing pass for One-Shot Bufferize; implemented in
+// SCF/Transforms/Bufferize.cpp. Rewrites tensors yielded from `scf.for`
+// loops through `bufferization.materialize_in_destination` so that loops
+// with non-equivalent yields can bufferize.
+def SCFLoopBufferizationPreprocessing
+ : Pass<"scf-loop-bufferization-preprocessing"> {
+ let summary = "Preprocess loops before One-Shot Bufferize";
+
+ let description = [{
+ Preprocess `scf.for` loops before running One-Shot Bufferize to support
+ loops where a yielded tensor is not equivalent to the respective iter_arg.
+ Such IR is currently not supported by One-Shot Bufferize.
+
+ This pass inserts a `bufferization.materialize_in_destination` op for every
+ yielded tensor, such that the yielded value is guaranteed to materialize in
+ the future buffer of the iter_arg; this is done by copying the tensor
+ contents into the iter_arg buffer. Such memcpys are a no-op in case the
+ tensor contents already materialize in the iter_arg buffer.
+ }];
+
+ let constructor = "mlir::createSCFLoopBufferizationPreprocessingPass()";
+ let dependentDialects = ["bufferization::BufferizationDialect",
+ "scf::SCFDialect"];
+}
+
// Note: Making these canonicalization patterns would require a dependency
// of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa.
def SCFForLoopCanonicalization
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 21c618ab633f..727c4fc7c639 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -17,6 +17,7 @@
namespace mlir {
#define GEN_PASS_DEF_SCFBUFFERIZE
+#define GEN_PASS_DEF_SCFLOOPBUFFERIZATIONPREPROCESSING
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
@@ -40,8 +41,40 @@ struct SCFBufferizePass : public impl::SCFBufferizeBase<SCFBufferizePass> {
return signalPassFailure();
};
};
+
+// Pass that rewrites the operands of `scf.yield` ops that terminate
+// `scf.for` loops: every yielded tensor is routed through a
+// `bufferization.materialize_in_destination` op targeting the corresponding
+// iter_arg. This works around One-Shot Bufferize's requirement that yielded
+// values be equivalent to their iter_args (per the pass description, the
+// resulting memcpy is a no-op when the value already lives in that buffer).
+struct SCFLoopBufferizationPreprocessingPass
+ : public impl::SCFLoopBufferizationPreprocessingBase<
+ SCFLoopBufferizationPreprocessingPass> {
+ void runOnOperation() override {
+ OpBuilder builder(getOperation()->getContext());
+ // Visit every `scf.yield`; only yields whose parent is an `scf.for`
+ // are rewritten, all other parents are skipped below.
+ getOperation()->walk([&](scf::YieldOp yieldOp) {
+ // New ops are inserted directly before the yield.
+ builder.setInsertionPoint(yieldOp);
+ // TODO: Support scf.while.
+ auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+ if (!forOp)
+ return WalkResult::skip();
+ for (OpOperand &operand : yieldOp->getOpOperands()) {
+ // Only tensor-typed yield operands are affected by bufferization.
+ auto tensorType = dyn_cast<TensorType>(operand.get().getType());
+ if (!tensorType)
+ continue;
+ // Yield operand #i feeds region iter_arg #i on the next iteration.
+ auto bbArg = forOp.getRegionIterArgs()[operand.getOperandNumber()];
+ // Materialize the yielded tensor in the (future) buffer of the
+ // iter_arg, then yield the materialized value instead.
+ Value materialized =
+ builder
+ .create<bufferization::MaterializeInDestinationOp>(
+ yieldOp.getLoc(), tensorType, operand.get(), bbArg)
+ .getResult();
+ operand.set(materialized);
+ }
+ return WalkResult::advance();
+ });
+ }
+};
} // namespace
std::unique_ptr<Pass> mlir::createSCFBufferizePass() {
return std::make_unique<SCFBufferizePass>();
}
+
+// Factory for the preprocessing pass; declared in SCF/Transforms/Passes.h
+// and referenced by the `constructor` field of the TableGen pass def.
+std::unique_ptr<Pass> mlir::createSCFLoopBufferizationPreprocessingPass() {
+ return std::make_unique<SCFLoopBufferizationPreprocessingPass>();
+}
diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir
new file mode 100644
index 000000000000..176611782450
--- /dev/null
+++ b/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -scf-loop-bufferization-preprocessing -one-shot-bufferize="bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" -canonicalize | FileCheck %s
+
+// Test: %tA is still read (tensor.extract) after %0 is produced, so the
+// tensor.insert cannot be performed in-place and %0 bufferizes into a new
+// allocation rather than the iter_arg buffer. Without the preprocessing
+// pass, One-Shot Bufferize would reject this loop ("Yield operand #0 is not
+// equivalent to the corresponding iter bbArg"); with it, the yielded value
+// is copied back into the iter_arg buffer (the final memref.copy below).
+// CHECK-LABEL: func @conflict_in_loop(
+// CHECK-SAME: %[[A:.*]]: memref<10xf32>
+func.func @conflict_in_loop(%A: tensor<10xf32>, %f: f32, %idx: index, %lb: index, %ub: index, %step: index) -> f32 {
+ // CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+ %r = scf.for %i = %lb to %ub step %step iter_args(%tA = %A) -> (tensor<10xf32>) {
+ // CHECK: %[[alloc:.*]] = memref.alloc()
+ // CHECK: memref.copy %[[A]], %[[alloc]]
+ // CHECK: memref.store %{{.*}}, %[[alloc]]
+ %0 = tensor.insert %f into %tA[%i] : tensor<10xf32>
+ // CHECK: %[[read:.*]] = memref.load %[[A]]
+ %read = tensor.extract %tA[%idx] : tensor<10xf32>
+ // CHECK: vector.print %[[read]]
+ vector.print %read : f32
+ // CHECK: memref.copy %[[alloc]], %[[A]]
+ scf.yield %0 : tensor<10xf32>
+ }
+
+ // CHECK: memref.load %[[A]]
+ %f0 = tensor.extract %r[%step] : tensor<10xf32>
+ return %f0 : f32
+}