diff options
author | Peiming Liu <36770114+PeimingLiu@users.noreply.github.com> | 2024-02-17 14:17:57 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-17 12:17:57 -0800 |
commit | 11705afc19383dedfb06c3b708d6fe8c0729b807 (patch) | |
tree | e93e79b69a43922cc500d6447337834a79f433db | |
parent | 164055f897c049543bd3e15548f891f823dc18b4 (diff) |
[mlir][sparse] deallocate tmp coo buffer generated during stage-sparse-ops pass. (#82017)
5 files changed, 35 insertions, 14 deletions
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h index ebbc522123a5..c0f31762ee07 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.h @@ -1,5 +1,4 @@ -//===- SparseTensorInterfaces.h - sparse tensor operations -//interfaces-------===// +//===- SparseTensorInterfaces.h - sparse tensor operations interfaces------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -20,7 +19,7 @@ class StageWithSortSparseOp; namespace detail { LogicalResult stageWithSortImpl(sparse_tensor::StageWithSortSparseOp op, - PatternRewriter &rewriter); + PatternRewriter &rewriter, Value &tmpBufs); } // namespace detail } // namespace sparse_tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td index 1379363ff75f..05eed0483f2c 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.td @@ -34,9 +34,10 @@ def StageWithSortSparseOpInterface : OpInterface<"StageWithSortSparseOp"> { /*desc=*/"Stage the operation, return the final result value after staging.", /*retTy=*/"::mlir::LogicalResult", /*methodName=*/"stageWithSort", - /*args=*/(ins "::mlir::PatternRewriter &":$rewriter), + /*args=*/(ins "::mlir::PatternRewriter &":$rewriter, + "Value &":$tmpBuf), /*methodBody=*/[{ - return detail::stageWithSortImpl($_op, rewriter); + return detail::stageWithSortImpl($_op, rewriter, tmpBuf); }]>, ]; } diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp index d33eb9d2877a..4f9988d48d77 100644 --- 
a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp @@ -16,9 +16,14 @@ using namespace mlir::sparse_tensor; #include "mlir/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp.inc" -LogicalResult -sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op, - PatternRewriter &rewriter) { +/// Stage the operations into a sequence of simple operations as follow: +/// op -> unsorted_coo + +/// unsorted_coo -> sorted_coo + +/// sorted_coo -> dstTp. +/// +/// return `tmpBuf` if a intermediate memory is allocated. +LogicalResult sparse_tensor::detail::stageWithSortImpl( + StageWithSortSparseOp op, PatternRewriter &rewriter, Value &tmpBufs) { if (!op.needsExtraSort()) return failure(); @@ -44,9 +49,15 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op, rewriter.replaceOp(op, dstCOO); } else { // Need an extra conversion if the target type is not COO. - rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO); + auto c = rewriter.replaceOpWithNewOp<ConvertOp>(op, finalTp, dstCOO); + rewriter.setInsertionPointAfter(c); + // Informs the caller about the intermediate buffer we allocated. We can not + // create a bufferization::DeallocateTensorOp here because it would + // introduce cyclic dependency between the SparseTensorDialect and the + // BufferizationDialect. Besides, whether the buffer need to be deallocated + // by SparseTensorDialect or by BufferDeallocationPass is still TBD. + tmpBufs = dstCOO; } - // TODO: deallocate extra COOs, we should probably delegate it to buffer - // deallocation pass. 
+ return success(); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp index 5875cd4f9fd9..5b4395cc31a4 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/StageSparseOperations.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" @@ -21,8 +22,16 @@ struct StageUnorderedSparseOps : public OpRewritePattern<StageWithSortOp> { LogicalResult matchAndRewrite(StageWithSortOp op, PatternRewriter &rewriter) const override { - return llvm::cast<StageWithSortSparseOp>(op.getOperation()) - .stageWithSort(rewriter); + Location loc = op.getLoc(); + Value tmpBuf = nullptr; + auto itOp = llvm::cast<StageWithSortSparseOp>(op.getOperation()); + LogicalResult stageResult = itOp.stageWithSort(rewriter, tmpBuf); + // Deallocate tmpBuf. + // TODO: Delegate to buffer deallocation pass in the future. 
+ if (succeeded(stageResult) && tmpBuf) + rewriter.create<bufferization::DeallocTensorOp>(loc, tmpBuf); + + return stageResult; } }; } // namespace diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir index 96a1140372bd..83dbc9568c7a 100644 --- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir +++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir @@ -82,10 +82,11 @@ func.func @sparse_constant_csc() -> tensor<8x7xf32, #CSC>{ // CHECK: scf.if // CHECK: tensor.insert // CHECK: sparse_tensor.load -// CHECK: sparse_tensor.reorder_coo +// CHECK: %[[TMP:.*]] = sparse_tensor.reorder_coo // CHECK: sparse_tensor.foreach // CHECK: tensor.insert // CHECK: sparse_tensor.load +// CHECK: bufferization.dealloc_tensor %[[TMP]] func.func @sparse_convert_3d(%arg0: tensor<?x?x?xf64>) -> tensor<?x?x?xf64, #SparseTensor> { %0 = sparse_tensor.convert %arg0 : tensor<?x?x?xf64> to tensor<?x?x?xf64, #SparseTensor> return %0 : tensor<?x?x?xf64, #SparseTensor> |