Diffstat (limited to 'mlir/lib/Conversion/TosaToLinalg')
-rw-r--r--   mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp        12
-rw-r--r--   mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp   88
2 files changed, 72 insertions, 28 deletions
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 7c477f2e1412..d8dd1c93722b 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -766,11 +766,15 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc,
   // Emit 'then' region of 'scf.if'
   auto emitThenRegion = [&](OpBuilder &opBuilder, Location loc) {
+    // It is not safe to cache constants across regions.
+    // New constants could potentially violate dominance requirements.
+    IndexPool localPool;
+
     // Emit 'tensor.empty' op
     SmallVector<OpFoldResult> outputTensorShape;
     for (auto index : llvm::seq<int64_t>(0, rank)) {
       auto size = index == dim ? targetSize
-                               : getOrFoldTensorDim(rewriter, loc, indexPool,
+                               : getOrFoldTensorDim(rewriter, loc, localPool,
                                                     operand, index);
       outputTensorShape.push_back(size);
     }
@@ -812,9 +816,9 @@ static Value broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
                                         IndexPool &indexPool, Value operand,
                                         ArrayRef<OpFoldResult> targetShape,
                                         ArrayRef<Value> masterOperands) {
-  size_t rank = operand.getType().cast<RankedTensorType>().getRank();
-  assert(targetShape.size() == rank);
-  assert(masterOperands.size() == rank);
+  int64_t rank = operand.getType().cast<RankedTensorType>().getRank();
+  assert((int64_t)targetShape.size() == rank);
+  assert((int64_t)masterOperands.size() == rank);
   for (auto index : llvm::seq<int64_t>(0, rank))
     operand =
         broadcastDynamicDimension(rewriter, loc, indexPool, operand, index,
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 3f39cbf03a9a..8fb8d1648656 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -26,6 +26,8 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+
 #include <numeric>
 #include <type_traits>
@@ -34,7 +36,7 @@
 using namespace mlir::tosa;
 
 static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
                             TypedAttr padAttr, OpBuilder &rewriter) {
-  // Input should be padded if necessary.
+  // Input should be padded only if necessary.
   if (llvm::all_of(pad, [](int64_t p) { return p == 0; }))
     return input;
@@ -47,7 +49,7 @@ static mlir::Value applyPad(Location loc, Value input, ArrayRef<int64_t> pad,
   SmallVector<int64_t, 4> paddedShape;
   SmallVector<OpFoldResult, 8> lowIndices;
   SmallVector<OpFoldResult, 8> highIndices;
-  for (int i = 0, s = inputShape.size(); i < s; i++) {
+  for (size_t i : llvm::seq(inputShape.size())) {
     auto lowPad = pad[i * 2];
     auto highPad = pad[i * 2 + 1];
     if (ShapedType::isDynamic(inputShape[i]))
@@ -131,20 +133,19 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
 
 static mlir::Value reifyConstantDim(int64_t attr,
                                     ImplicitLocOpBuilder &builder) {
-  return builder.createOrFold<arith::IndexCastOp>(
-      builder.getIndexType(),
-      builder.create<arith::ConstantOp>(builder.getI64IntegerAttr(attr)));
+  return builder.create<arith::ConstantIndexOp>(attr);
 }
 
 // Calculating the output width/height using the formula:
 // H = ((IH+pad_top+pad_bottom-(dilation_y*(KH-1)+1))/stride_y)+1
 // W = ((IW+pad_left+pad_right-(dilation_x*(KW-1)+1))/stride_x)+1
 
-static mlir::Value getConvOutputDim(Location loc, Value inputDim,
-                                    int64_t padBeforeAttr, int64_t padAfterAttr,
-                                    Value kernelDim, int64_t strideAttr,
-                                    int64_t dilationAttr, Type inputETy,
-                                    OpBuilder &rewriter) {
+static mlir::Value getConvOrPoolOutputDim(Location loc, Value inputDim,
+                                          int64_t padBeforeAttr,
+                                          int64_t padAfterAttr, Value kernelDim,
+                                          int64_t strideAttr,
+                                          int64_t dilationAttr,
+                                          OpBuilder &rewriter) {
   ImplicitLocOpBuilder builder(loc, rewriter);
   auto one = rewriter.create<arith::ConstantOp>(
       loc, IntegerAttr::get(inputDim.getType(), 1));
@@ -171,7 +172,6 @@ static SmallVector<Value> inferDynamicDimsForConv(
     ArrayRef<int64_t> dilationAttr, ArrayRef<int64_t> inputSizeDims,
     ArrayRef<int64_t> kernelSizeDims, OpBuilder &rewriter) {
   ShapedType inputTy = cast<ShapedType>(input.getType());
-  Type inputETy = inputTy.getElementType();
   int64_t inputRank = inputTy.getRank();
 
   SmallVector<Value> dynDims;
@@ -190,8 +190,8 @@ static SmallVector<Value> inferDynamicDimsForConv(
         rewriter.create<tensor::DimOp>(loc, weight, kernelDim);
     // H = F(IH, pad_top, pad_bottom, dilation_y, KH, stride_y)
     dynDims[inputDim] =
-        getConvOutputDim(loc, initDynDim, padTop, padBottom, kernelDynDim,
-                         stride, dilation, inputETy, rewriter);
+        getConvOrPoolOutputDim(loc, initDynDim, padTop, padBottom,
+                               kernelDynDim, stride, dilation, rewriter);
   }
 }
@@ -685,20 +685,61 @@ class MaxPool2dConverter : public OpRewritePattern<tosa::MaxPool2dOp> {
 public:
   using OpRewritePattern<tosa::MaxPool2dOp>::OpRewritePattern;
 
+  // Compute the dynamic output sizes of the maxpool operation.
+  static SmallVector<Value>
+  computeDynamicOutputSizes(tosa::MaxPool2dOp op, PatternRewriter &rewriter) {
+    TensorType resultTy = op.getType();
+    Location loc = op.getLoc();
+
+    TypedValue<TensorType> input = op.getInput();
+    ArrayRef<int64_t> kernel = op.getKernel();
+    ArrayRef<int64_t> pad = op.getPad();
+    ArrayRef<int64_t> stride = op.getStride();
+
+    SmallVector<Value> dynamicDims;
+
+    // Batch dimension
+    if (resultTy.isDynamicDim(0))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0));
+
+    // Height/width dimensions
+    for (int64_t dim : {1, 2}) {
+      if (!resultTy.isDynamicDim(dim))
+        continue;
+
+      // Index into the attribute arrays
+      int64_t index = dim - 1;
+
+      // Input height/width
+      Value ihw = rewriter.create<tensor::DimOp>(loc, input, dim);
+
+      // Kernel height/width
+      Value khw = rewriter.create<arith::ConstantIndexOp>(loc, kernel[index]);
+
+      // Output height/width
+      Value ohw = getConvOrPoolOutputDim(loc, ihw, pad[index * 2],
+                                         pad[index * 2 + 1], khw, stride[index],
+                                         /*dilationAttr=*/1, rewriter);
+      dynamicDims.push_back(ohw);
+    }
+
+    // Channel dimension
+    if (resultTy.isDynamicDim(3))
+      dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 3));
+
+    return dynamicDims;
+  }
+
   LogicalResult matchAndRewrite(tosa::MaxPool2dOp op,
                                 PatternRewriter &rewriter) const final {
     Location loc = op.getLoc();
-    Value input = op.getInput();
-    ShapedType inputTy = cast<ShapedType>(input.getType());
+    TypedValue<TensorType> input = op.getInput();
+    ShapedType inputTy = input.getType();
 
-    ShapedType resultTy = cast<ShapedType>(op.getType());
+    ShapedType resultTy = op.getType();
     Type resultETy = inputTy.getElementType();
 
-    auto dynamicDimsOr =
-        checkHasDynamicBatchDims(rewriter, op, {input, op.getOutput()});
-    if (!dynamicDimsOr.has_value())
-      return failure();
-    SmallVector<Value> dynamicDims = *dynamicDimsOr;
+    SmallVector<Value> dynamicDims = computeDynamicOutputSizes(op, rewriter);
 
     // Determine what the initial value needs to be for the max pool op.
     TypedAttr initialAttr;
@@ -721,6 +762,7 @@ public:
     pad.resize(2, 0);
     llvm::append_range(pad, op.getPad());
     pad.resize(pad.size() + 2, 0);
+
     Value paddedInput = applyPad(loc, input, pad, initialAttr, rewriter);
 
     Value initialValue = rewriter.create<arith::ConstantOp>(loc, initialAttr);
@@ -736,9 +778,7 @@ public:
         loc, resultTy.getShape(), resultTy.getElementType(), dynamicDims);
 
     Value filledEmptyTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{initialValue},
-                                    ValueRange{emptyTensor})
+        rewriter.create<linalg::FillOp>(loc, initialValue, emptyTensor)
             .result();
 
     Value fakeWindowDims =
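
A note on the first TosaToLinalg.cpp hunk: an IndexPool caches index constants so that repeated shape queries reuse a single arith.constant op, but a constant materialized while building one scf.if region does not dominate uses emitted later in a sibling region, which is why the patch builds a region-local pool. A minimal sketch of such a cache, assuming IndexPool is essentially a map from integer value to cached Value (getOrCreateIndex is an illustrative name, not the upstream helper):

  #include "mlir/Dialect/Arith/IR/Arith.h"
  #include "mlir/IR/Builders.h"
  #include "llvm/ADT/DenseMap.h"

  using namespace mlir;

  // Illustrative stand-in for the pool: constant value -> cached SSA Value.
  using IndexPool = llvm::DenseMap<int64_t, Value>;

  // Returns a cached arith.constant of index type, creating it at the
  // builder's current insertion point on first use. If the pool outlives
  // the region in which a constant was created, a later lookup can hand
  // back a Value that does not dominate the new use site -- hence the
  // region-local 'localPool' in the patch.
  static Value getOrCreateIndex(OpBuilder &builder, Location loc,
                                IndexPool &pool, int64_t value) {
    Value &cached = pool[value];
    if (!cached)
      cached = builder.create<arith::ConstantIndexOp>(loc, value);
    return cached;
  }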
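
The output-size formula that getConvOrPoolOutputDim now serves for both convolutions and pooling (computeDynamicOutputSizes passes dilation = 1 for max_pool2d) can be sanity-checked in plain C++; the sample shapes below are illustrative, not taken from the commit:

  #include <cstdint>
  #include <iostream>

  // H = ((IH + pad_top + pad_bottom - (dilation_y*(KH-1)+1)) / stride_y) + 1
  static int64_t outputDim(int64_t in, int64_t padBefore, int64_t padAfter,
                           int64_t kernel, int64_t stride, int64_t dilation) {
    return (in + padBefore + padAfter - (dilation * (kernel - 1) + 1)) / stride + 1;
  }

  int main() {
    // Conv case: 112x112 input, 3x3 kernel, stride 2, pad 1/1, dilation 1 -> 56.
    std::cout << outputDim(112, 1, 1, 3, 2, 1) << "\n";
    // max_pool2d case: size 7, 3-wide kernel, stride 2, no padding -> 3.
    std::cout << outputDim(7, 0, 0, 3, 2, 1) << "\n";
  }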