diff options
author | Prashant Kumar <pk5561@gmail.com> | 2024-04-03 22:19:26 +0530 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-03 22:19:26 +0530 |
commit | 5b702be1e80b8733786ac48ceaf04f2936616d1b (patch) | |
tree | fd3e66b633fef53a2e8787f4aa731433054762d6 | |
parent | 17642c76023b7f421dac8e9fb176b0221e309a8a (diff) |
[mlir][math] Convert math.fpowi to math.powf in case of non constant (#87472)
Convert math.fpowi to math.powf by converting dtype of power operand to
floating point.
-rw-r--r-- | mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 20 | ||||
-rw-r--r-- | mlir/test/Dialect/Math/expand-math.mlir | 48 |
2 files changed, 63 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index 0b8546251350..42629e149e9f 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -216,20 +216,30 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) { // Convert `math.fpowi` to a series of `arith.mulf` operations. // If the power is negative, we divide one by the result. // If both the base and power are zero, the result is 1. -static LogicalResult convertFPowICstOp(math::FPowIOp op, - PatternRewriter &rewriter) { +// In the case of non constant power, we convert the operation to `math.powf`. +static LogicalResult convertFPowIOp(math::FPowIOp op, + PatternRewriter &rewriter) { ImplicitLocOpBuilder b(op->getLoc(), rewriter); Value base = op.getOperand(0); Value power = op.getOperand(1); Type baseType = base.getType(); + auto convertFPowItoPowf = [&]() -> LogicalResult { + Value castPowerToFp = + rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power); + Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base, + castPowerToFp); + rewriter.replaceOp(op, res); + return success(); + }; + Attribute cstAttr; if (!matchPattern(power, m_Constant(&cstAttr))) - return failure(); + return convertFPowItoPowf(); APInt value; if (!matchPattern(cstAttr, m_ConstantInt(&value))) - return failure(); + return convertFPowItoPowf(); int64_t powerInt = value.getSExtValue(); bool isNegative = powerInt < 0; @@ -591,7 +601,7 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { } void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { - patterns.add(convertFPowICstOp); + patterns.add(convertFPowIOp); } void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index bfcff27bd64e..3d94b55126d0 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -610,3 +610,51 @@ func.func @math_fpowi_scalar_zero(%0 : f32) -> f32 { // CHECK: return %[[RET]] : f32 // ----- + +// CHECK-LABEL: func.func @math_fpowi_to_powf_tensor +func.func @math_fpowi_to_powf_tensor(%0 : tensor<8xf32>, %1: tensor<8xi32>) -> tensor<8xf32> { + %2 = math.fpowi %0, %1 : tensor<8xf32>, tensor<8xi32> + return %2 : tensor<8xf32> +} +// CHECK-SAME: (%[[ARG0:.*]]: tensor<8xf32>, %[[ARG1:.*]]: tensor<8xi32>) -> tensor<8xf32> { +// CHECK: %[[CSTNEG1:.*]] = arith.constant dense<-1.000000e+00> : tensor<8xf32> +// CHECK: %[[CST2:.*]] = arith.constant dense<2.000000e+00> : tensor<8xf32> +// CHECK: %[[CST0:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> +// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : tensor<8xi32> to tensor<8xf32> +// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : tensor<8xf32> +// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : tensor<8xf32> +// CHECK: %[[LG:.*]] = math.log %[[SQ]] : tensor<8xf32> +// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : tensor<8xf32> +// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : tensor<8xf32> +// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : tensor<8xf32> +// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : tensor<8xf32> +// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : tensor<8xf32> +// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : tensor<8xf32> +// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : tensor<8xi1> +// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : tensor<8xi1>, tensor<8xf32> +// CHECK: return %[[SEL]] : tensor<8xf32> + +// ----- + +// CHECK-LABEL: func.func @math_fpowi_to_powf_scalar +func.func @math_fpowi_to_powf_scalar(%0 : f32, %1: i64) -> f32 { + %2 = math.fpowi %0, %1 : f32, i64 + return %2 : f32 +} +// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: i64) -> f32 { +// CHECK: %[[CSTNEG1:.*]] = arith.constant -1.000000e+00 : f32 +// CHECK: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK: %[[CST0:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[TOFP:.*]] = arith.sitofp %[[ARG1]] : i64 to f32 +// CHECK: %[[SQ:.*]] = arith.mulf %[[ARG0]], %[[ARG0]] : f32 +// CHECK: %[[DIV:.*]] = arith.divf %[[TOFP]], %[[CST2]] : f32 +// CHECK: %[[LG:.*]] = math.log %[[SQ]] : f32 +// CHECK: %[[MUL:.*]] = arith.mulf %[[DIV]], %[[LG]] : f32 +// CHECK: %[[EXP:.*]] = math.exp %[[MUL]] : f32 +// CHECK: %[[MUL1:.*]] = arith.mulf %[[EXP]], %[[CSTNEG1]] : f32 +// CHECK: %[[REM:.*]] = arith.remf %[[TOFP]], %[[CST2]] : f32 +// CHECK: %[[CMPF:.*]] = arith.cmpf olt, %[[ARG0]], %[[CST0]] : f32 +// CHECK: %[[CMPF1:.*]] = arith.cmpf one, %[[REM]], %[[CST0]] : f32 +// CHECK: %[[AND:.*]] = arith.andi %[[CMPF1]], %[[CMPF]] : i1 +// CHECK: %[[SEL:.*]] = arith.select %[[AND]], %[[MUL1]], %[[EXP]] : f32 +// CHECK: return %[[SEL]] : f32 |