diff options
author | Matthias Springer <springerm@google.com> | 2024-04-04 03:07:43 +0000 |
---|---|---|
committer | Matthias Springer <springerm@google.com> | 2024-04-04 03:07:45 +0000 |
commit | 4c819864c5edbdb8137451c0a0b6a97240a17008 (patch) | |
tree | 3617ed9d8ff7712b6c445d3629bdeb9c0d92d3d3 | |
parent | 9df19ce40281551bd348b262a131085cf98dadf5 (diff) |
[mlir][SCF] Add `scf.for` bufferization preprocessing pass
branch: upstream/users/matthias-springer/scf_bufferization_preprocessing
Add a bufferization preprocessing pass for `scf.for` loops to support loops where a yielded tensor value does not bufferize to the equivalent corresponding iter_arg buffer. This preprocessing works around a limitation of `scf.for` bufferization by inserting additional buffer copies for yielded tensors.
This preprocessing can be used to support most cases where One-Shot Bufferize fails to bufferize the IR with the following error message:
```
error: Yield operand #0 is not equivalent to the corresponding iter bbArg
```
4 files changed, 80 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h index 90b315e83a8c..6107219ea94a 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.h @@ -23,6 +23,9 @@ namespace mlir { /// Creates a pass that bufferizes the SCF dialect. std::unique_ptr<Pass> createSCFBufferizePass(); +/// Creates a pass that preprocesses SCF loop before One-Shot Bufferize. +std::unique_ptr<Pass> createSCFLoopBufferizationPreprocessingPass(); + /// Creates a pass that specializes for loop for unrolling and /// vectorization. std::unique_ptr<Pass> createForLoopSpecializationPass(); diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td index 350611ad8687..94d3e51a1c90 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/SCF/Transforms/Passes.td @@ -18,6 +18,27 @@ def SCFBufferize : Pass<"scf-bufferize"> { "memref::MemRefDialect"]; } +def SCFLoopBufferizationPreprocessing + : Pass<"scf-loop-bufferization-preprocessing"> { + let summary = "Preprocess loops before One-Shot Bufferize"; + + let description = [{ + Preprocess `scf.for` loops before running One-Shot Bufferize to support + loops where a yielded tensor is not equivalent to the respective iter_arg. + Such IR is currently not supported by One-Shot Bufferize. + + This pass inserts a `bufferization.materialize_in_destination` op for every + yielded tensor, such that the yielded value is guaranteed to materialize in + the future buffer of the iter_arg; this is done by copying the tensor + contents into the iter_arg buffer. Such memcpys are a no-op in case the + tensor contents already materialize in the iter_arg buffer. 
+ }]; + + let constructor = "mlir::createSCFLoopBufferizationPreprocessingPass()"; + let dependentDialects = ["bufferization::BufferizationDialect", + "scf::SCFDialect"]; +} + // Note: Making these canonicalization patterns would require a dependency // of the SCF dialect on the Affine/Tensor/MemRef dialects or vice versa. def SCFForLoopCanonicalization diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp index 21c618ab633f..727c4fc7c639 100644 --- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp @@ -17,6 +17,7 @@ namespace mlir { #define GEN_PASS_DEF_SCFBUFFERIZE +#define GEN_PASS_DEF_SCFLOOPBUFFERIZATIONPREPROCESSING #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir @@ -40,8 +41,40 @@ struct SCFBufferizePass : public impl::SCFBufferizeBase<SCFBufferizePass> { return signalPassFailure(); }; }; + +struct SCFLoopBufferizationPreprocessingPass + : public impl::SCFLoopBufferizationPreprocessingBase< + SCFLoopBufferizationPreprocessingPass> { + void runOnOperation() override { + OpBuilder builder(getOperation()->getContext()); + getOperation()->walk([&](scf::YieldOp yieldOp) { + builder.setInsertionPoint(yieldOp); + // TODO: Support scf.while. 
+ auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()); + if (!forOp) + return WalkResult::skip(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + auto tensorType = dyn_cast<TensorType>(operand.get().getType()); + if (!tensorType) + continue; + auto bbArg = forOp.getRegionIterArgs()[operand.getOperandNumber()]; + Value materialized = + builder + .create<bufferization::MaterializeInDestinationOp>( + yieldOp.getLoc(), tensorType, operand.get(), bbArg) + .getResult(); + operand.set(materialized); + } + return WalkResult::advance(); + }); + } +}; } // namespace std::unique_ptr<Pass> mlir::createSCFBufferizePass() { return std::make_unique<SCFBufferizePass>(); } + +std::unique_ptr<Pass> mlir::createSCFLoopBufferizationPreprocessingPass() { + return std::make_unique<SCFLoopBufferizationPreprocessingPass>(); +} diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir new file mode 100644 index 000000000000..176611782450 --- /dev/null +++ b/mlir/test/Dialect/SCF/one-shot-bufferize-preprocessing.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt %s -scf-loop-bufferization-preprocessing -one-shot-bufferize="bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" -canonicalize | FileCheck %s + +// CHECK-LABEL: func @conflict_in_loop( +// CHECK-SAME: %[[A:.*]]: memref<10xf32> +func.func @conflict_in_loop(%A: tensor<10xf32>, %f: f32, %idx: index, %lb: index, %ub: index, %step: index) -> f32 { + // CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} { + %r = scf.for %i = %lb to %ub step %step iter_args(%tA = %A) -> (tensor<10xf32>) { + // CHECK: %[[alloc:.*]] = memref.alloc() + // CHECK: memref.copy %[[A]], %[[alloc]] + // CHECK: memref.store %{{.*}}, %[[alloc]] + %0 = tensor.insert %f into %tA[%i] : tensor<10xf32> + // CHECK: %[[read:.*]] = memref.load %[[A]] + %read = tensor.extract %tA[%idx] : tensor<10xf32> + // CHECK: vector.print %[[read]] + vector.print %read : f32 + // CHECK: memref.copy %[[alloc]], %[[A]] + scf.yield %0 : tensor<10xf32> + } + + // CHECK: memref.load %[[A]] + %f0 = tensor.extract %r[%step] : tensor<10xf32> + return %f0 : f32 +}