diff options
author | Wang Pengcheng <wangpengcheng.pp@bytedance.com> | 2024-04-16 21:27:31 +0800 |
---|---|---|
committer | Wang Pengcheng <wangpengcheng.pp@bytedance.com> | 2024-04-16 21:27:31 +0800 |
commit | 36640769547bedf26ddf149132c1b75f9e088a21 (patch) | |
tree | ae7d210d8d7593f7bb672006e4a31f192a02b9c5 /mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | |
parent | d72e50aae48ffed5fb6c1a9ad6bfc47c5ca93230 (diff) | |
parent | e7fb49c24e4be4780ee4df9829980c5e8ddd511e (diff) |
Created using spr 1.3.6-beta.1
Diffstat (limited to 'mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp')
-rw-r--r-- | mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 12 |
1 file changed, 8 insertions, 4 deletions
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412..d8dd1c93722b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -766,11 +766,15 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
 
   // Emit 'then' region of 'scf.if'
   auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+    // It is not safe to cache constants across regions.
+    // New constants could potentially violate dominance requirements.
+    IndexPool localPool;
+
     // Emit 'tensor.empty' op
     SmallVector<OpFoldResult> outputTensorShape;
     for (auto index : llvm::seq<int64_t>(0, rank)) {
       auto size = index == dim ? targetSize
-                               : getOrFoldTensorDim(rewriter, loc, indexPool,
+                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                     operand, index);
       outputTensorShape.push_back(size);
     }
@@ -812,9 +816,9 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                         IndexPool &indexPool, Value operand,
                                         ArrayRef<OpFoldResult> targetShape,
                                         ArrayRef<Value> masterOperands) {
-  size_t rank = operand.getType().cast<RankedTensorType>().getRank();
-  assert(targetShape.size() == rank);
-  assert(masterOperands.size() == rank);
+  int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+  assert((int64_t)targetShape.size() == rank);
+  assert((int64_t)masterOperands.size() == rank);
   for (auto index : llvm::seq<int64_t>(0, rank))
     operand = broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,