summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJohannes Reifferscheid <jreiffers@google.com>2024-02-23 13:15:08 +0100
committerGitHub <noreply@github.com>2024-02-23 13:15:08 +0100
commitbcf9826a5392f40063869c3d2b72a5cd1b87d14b (patch)
tree3831e2076350b5c5ebac5a5b72e900a51f6ff808
parent3b3d0978c334702114131e4dab549aa25b9f0ad4 (diff)
[MLIR] Expose approximation patterns for tanh/erf. (#82750)
These patterns can already be used via populateMathPolynomialApproximationPatterns, but that includes a number of other patterns that may not be needed. There are already similar functions for expansion. For now only adding tanh and erf since I have a concrete use case for these two.
-rw-r--r--mlir/include/mlir/Dialect/Math/Transforms/Passes.h3
-rw-r--r--mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp10
2 files changed, 13 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index 010dde5ea738..11b2c7a7afa2 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -45,6 +45,9 @@ struct MathPolynomialApproximationOptions {
bool enableAvx2 = false;
};
+void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
+void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
+
void populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options = {});
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index 71e4e13103f5..962cb28b7c2a 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -1471,6 +1471,16 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
//----------------------------------------------------------------------------//
+void mlir::populatePolynomialApproximateTanhPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<TanhApproximation>(patterns.getContext());
+}
+
+void mlir::populatePolynomialApproximateErfPattern(
+ RewritePatternSet &patterns) {
+ patterns.add<ErfPolynomialApproximation>(patterns.getContext());
+}
+
void mlir::populateMathPolynomialApproximationPatterns(
RewritePatternSet &patterns,
const MathPolynomialApproximationOptions &options) {