Diffstat (limited to 'mlir/lib/Conversion/TosaToLinalg')
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp      | 12
 mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp | 88
 2 files changed, 72 insertions(+), 28 deletions(-)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412..d8dd1c93722b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -766,11 +766,15 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
// Emit 'then' region of 'scf.if'
auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+ // It is not safe to cache constants across regions.
+ // New constants could potentially violate dominance requirements.
+ IndexPool localPool;
+
// Emit 'tensor.empty' op
SmallVector<OpFoldResult> outputTensorShape;
for (auto index : llvm::seq<int64_t>(0, rank)) {
auto size = index == dim ? targetSize
- : getOrFoldTensorDim(rewriter, loc, indexPool,
+ : getOrFoldTensorDim(rewriter, loc, localPool,
operand, index);
outputTensorShape.push_back(size);
}
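[Note: a minimal sketch of the dominance hazard the local pool avoids, assuming IndexPool is a DenseMap<int64_t, Value> caching index constants by value; the scf.if below is illustrative, not this file's exact code.]

    // A constant created inside one region of an scf.if must not be served
    // from a shared cache in a sibling region:
    rewriter.create<scf::IfOp>(
        loc, condition,
        /*thenBuilder=*/[&](OpBuilder &b, Location loc) {
          // Created inside the 'then' region and cached in the outer pool.
          Value c1 = b.create<arith::ConstantIndexOp>(loc, 1);
          indexPool[1] = c1;
          b.create<scf::YieldOp>(loc, c1);
        },
        /*elseBuilder=*/[&](OpBuilder &b, Location loc) {
          // Cache hit returns the 'then'-region value, which does not
          // dominate this use -> invalid IR. A region-local pool avoids it.
          b.create<scf::YieldOp>(loc, indexPool.lookup(1));
        });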
@@ -812,9 +816,9 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
IndexPool &indexPool, Value operand,
ArrayRef<OpFoldResult> targetShape,
ArrayRef<Value> masterOperands) {
- size_t rank = operand.getType().cast<RankedTensorType>().getRank();
- assert(targetShape.size() == rank);
- assert(masterOperands.size() == rank);
+ int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+ assert((int64_t)targetShape.size() == rank);
+ assert((int64_t)masterOperands.size() == rank);
for (auto index : llvm::seq<int64_t>(0, rank))
operand =
broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 3f39cbf03a9a..8fb8d1648656 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
#include <numeric>
#include <type_traits>
@@ -34,7 +36,7 @@ using namespace mlir::tosa;
static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
TypedAttr padAttr, OpBuilder &rewriter) {
- // Input should be padded if necessary.
+ // Input should be padded only if necessary.
if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
return input;
@@ -47,7 +49,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
SmallVector<int64_t, 4> paddedShape;
SmallVector<OpFoldResult, 8> lowIndices;
SmallVector<OpFoldResult, 8> highIndices;
- for (int i = 0, s = inputShape.size(); i < s; i++) {
+ for (size_t i : llvm::seq(inputShape.size())) {
auto lowPad = pad[i * 2];
auto highPad = pad[i * 2 + 1];
if (ShapedType::isDynamic(inputShape[i]))
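[Note: a worked example of the interleaved pad layout the loop above unpacks; the NHWC sizes are hypothetical.]

    // pad stores (low, high) pairs per dimension: pad[2*i] pads before
    // dimension i and pad[2*i + 1] pads after it, so
    //   paddedShape[i] = inputShape[i] + pad[2 * i] + pad[2 * i + 1].
    SmallVector<int64_t> inputShape = {1, 14, 14, 3};
    SmallVector<int64_t> pad = {0, 0, 1, 1, 1, 1, 0, 0};
    SmallVector<int64_t> paddedShape;
    for (size_t i : llvm::seq(inputShape.size()))
      paddedShape.push_back(inputShape[i] + pad[i * 2] + pad[i * 2 + 1]);
    // paddedShape == {1, 16, 16, 3}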
@@ -131,20 +133,19 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
static mlir::Value reifyConstantDim(int64_t attr,
ImplicitLocOpBuilder &builder) {
- return builder.createOrFold<arith::IndexCastOp>(
- builder.getIndexType(),
- builder.create<arith::ConstantOp>(builder.getI64IntegerAttr(attr)));
+ return builder.create<arith::ConstantIndexOp>(attr);
}
// Calculating the output width/height using the formula:
// H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
// W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
-static mlir::Value getConvOutputDim(Location loc, Value inputDim,
- int64_t padBeforeAttr, int64_t padAfterAttr,
- Value kernelDim, int64_t strideAttr,
- int64_t dilationAttr, Type inputETy,
- OpBuilder &rewriter) {
+static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
+ int64_t padBeforeAttr,
+ int64_t padAfterAttr, Value kernelDim,
+ int64_t strideAttr,
+ int64_t dilationAttr,
+ OpBuilder &rewriter) {
ImplicitLocOpBuilder builder(loc, rewriter);
auto one = rewriter.create<arith::ConstantOp>(
loc, IntegerAttr::get(inputDim.getType(), 1));
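[Note: a plain-integer reference of the output-size formula above, useful for sanity-checking the arith ops the helper emits; the function name and values are hypothetical.]

    // H = ((IH + pad_top + pad_bottom - (dilation_y*(KH-1)+1)) / stride_y) + 1
    static int64_t refOutputDim(int64_t in, int64_t padBefore, int64_t padAfter,
                                int64_t kernel, int64_t stride,
                                int64_t dilation) {
      int64_t effectiveKernel = dilation * (kernel - 1) + 1;
      return (in + padBefore + padAfter - effectiveKernel) / stride + 1;
    }
    // refOutputDim(/*in=*/14, /*padBefore=*/1, /*padAfter=*/1,
    //              /*kernel=*/3, /*stride=*/2, /*dilation=*/1) == 7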
@@ -171,7 +172,6 @@ static SmallVector<Value> inferDynamicDimsForConv(
ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
ShapedType inputTy = cast<ShapedType>(input.getType());
- Type inputETy = inputTy.getElementType();
int64_t inputRank = inputTy.getRank();
SmallVector<Value> dynDims;
@@ -190,8 +190,8 @@ static SmallVector<Value> inferDynamicDimsForConv(
rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
// H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
dynDims[inputDim] =
- getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
- stride, dilation, inputETy, rewriter);
+ getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
+ kernelDynDim, stride, dilation, rewriter);
}
}
@@ -685,20 +685,61 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
public:
using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
+ // Compute the dynamic output sizes of the maxpool operation.
+ static SmallVector<Value>
+ computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+ TensorType resultTy = op.getType();
+ Location loc = op.getLoc();
+
+ TypedValue<TensorType> input = op.getInput();
+ ArrayRef<int64_t> kernel = op.getKernel();
+ ArrayRef<int64_t> pad = op.getPad();
+ ArrayRef<int64_t> stride = op.getStride();
+
+ SmallVector<Value> dynamicDims;
+
+ // Batch dimension
+ if (resultTy.isDynamicDim(0))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+ // Height/width dimensions
+ for (int64_t dim : {1, 2}) {
+ if (!resultTy.isDynamicDim(dim))
+ continue;
+
+ // Index into the attribute arrays
+ int64_t index = dim - 1;
+
+ // Input height/width
+ Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
+
+ // Kernel height/width
+ Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
+
+ // Output height/width
+ Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
+ pad[index * 2 + 1], khw, stride[index],
+ /*dilationAttr=*/1, rewriter);
+ dynamicDims.push_back(ohw);
+ }
+
+ // Channel dimension
+ if (resultTy.isDynamicDim(3))
+ dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
+
+ return dynamicDims;
+ }
+
LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
PatternRewriter &rewriter) const final {
Location loc = op.getLoc();
- Value input = op.getInput();
- ShapedType inputTy = cast<ShapedType>(input.getType());
+ TypedValue<TensorType> input = op.getInput();
+ ShapedType inputTy = input.getType();
- ShapedType resultTy = cast<ShapedType>(op.getType());
+ ShapedType resultTy = op.getType();
Type resultETy = inputTy.getElementType();
- auto dynamicDimsOr =
- checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
- if (!dynamicDimsOr.has_value())
- return failure();
- SmallVector<Value> dynamicDims = *dynamicDimsOr;
+ SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
// Determine what the initial value needs to be for the max pool op.
TypedAttr initialAttr;
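[Note: the returned dynamicDims must hold exactly one Value per dynamic result dimension, in NHWC dimension order, because tensor::EmptyOp (see the last hunk below) consumes them positionally. A sketch for a hypothetical tensor<?x?x?x8xf32> result:]

    // dynamicDims = {batch, outHeight, outWidth}; the static channel size 8
    // is taken from resultTy.getShape(), so no Value is supplied for it.
    Value empty = rewriter.create<tensor::EmptyOp>(
        loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);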
@@ -721,6 +762,7 @@ public:
pad.resize(2, 0);
llvm::append_range(pad, op.getPad());
pad.resize(pad.size() + 2, 0);
+
Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
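[Note: a trace of how the resize/append/resize sequence above builds the 8-entry NHWC pad list, with a hypothetical op pad of {top, bottom, left, right} = {1, 1, 2, 2}.]

    // pad = {}                                   initially empty
    // pad.resize(2, 0)                        -> {0, 0}              batch (N)
    // llvm::append_range(pad, op.getPad())    -> {0, 0, 1, 1, 2, 2}  H and W
    // pad.resize(pad.size() + 2, 0)           -> {0, 0, 1, 1, 2, 2, 0, 0}  channel (C)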
@@ -736,9 +778,7 @@ public:
loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);
Value filledEmptyTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{initialValue},
- ValueRange{emptyTensor})
+ rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
.result();
Value fakeWindowDims =