summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAlex MacLean <amaclean@nvidia.com>2024-04-23 08:56:39 -0700
committerGitHub <noreply@github.com>2024-04-23 08:56:39 -0700
commitdf608051234c256f1dc2c89f30afd034706c2c2e (patch)
tree11e5882cfcf1a7881dda4b31664c24403157df6d
parentf426be195a08874686d01783bbc490295bf4afb2 (diff)
[NVPTX] Improve support for rsqrt.approx (#89417)
Complete support for rsqrt.approx with rsqrt.approx.f64 ([PTX ISA 9.7.3.17. Floating Point Instructions: rsqrt.approx.ftz.f64](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64)). Additionally, add support for folding `sqrt` into `rsqrt`, with an optional flag to disable.
-rw-r--r--llvm/include/llvm/IR/IntrinsicsNVVM.td2
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp6
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h1
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXInstrInfo.td1
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td25
-rw-r--r--llvm/test/CodeGen/NVPTX/rsqrt-opt.ll75
-rw-r--r--llvm/test/CodeGen/NVPTX/rsqrt.ll35
7 files changed, 145 insertions, 0 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 726cea004606..0a9139e0062b 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1003,6 +1003,8 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_rsqrt_approx_ftz_f : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_f">,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
+ def int_nvvm_rsqrt_approx_ftz_d : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_d">,
+ DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>;
def int_nvvm_rsqrt_approx_f : ClangBuiltin<"__nvvm_rsqrt_approx_f">,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
def int_nvvm_rsqrt_approx_d : ClangBuiltin<"__nvvm_rsqrt_approx_d">,
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index a362709c98ef..595395bb1b4b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -30,6 +30,10 @@ using namespace llvm;
#define DEBUG_TYPE "nvptx-isel"
#define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
+static cl::opt<bool>
+ EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden,
+ cl::desc("Enable reciprocal sqrt optimization"));
+
/// createNVPTXISelDag - This pass converts a legalized DAG into a
/// NVPTX-specific DAG, ready for instruction scheduling.
FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
@@ -74,6 +78,8 @@ bool NVPTXDAGToDAGISel::allowUnsafeFPMath() const {
return TL->allowUnsafeFPMath(*MF);
}
+bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; }
+
/// Select - Select instructions not customized! Used for
/// expanded, promoted and normal instructions.
void NVPTXDAGToDAGISel::Select(SDNode *N) {
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
index 10822f87cef3..7a7774744bc7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
bool useF32FTZ() const;
bool allowFMA() const;
bool allowUnsafeFPMath() const;
+ bool doRsqrtOpt() const;
public:
static char ID;
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 931292c7fd60..897ee89323f0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -142,6 +142,7 @@ def hasLDU : Predicate<"Subtarget->hasLDU()">;
def doF32FTZ : Predicate<"useF32FTZ()">;
def doNoF32FTZ : Predicate<"!useF32FTZ()">;
+def doRsqrtOpt : Predicate<"doRsqrtOpt()">;
def doMulWide : Predicate<"doMulWide">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index ec9170b4e41e..5f6e28283c5d 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1171,11 +1171,36 @@ def : Pat<(int_nvvm_sqrt_f Float32Regs:$a),
def INT_NVVM_RSQRT_APPROX_FTZ_F
: F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs,
int_nvvm_rsqrt_approx_ftz_f>;
+def INT_NVVM_RSQRT_APPROX_FTZ_D
+ : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs,
+ int_nvvm_rsqrt_approx_ftz_d>;
+
def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;",
Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>;
def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;",
Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>;
+// 1.0f / sqrt_approx -> rsqrt_approx
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_ftz_f Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt]>;
+// same for int_nvvm_sqrt_f when non-precision sqrt is requested
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
+
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>;
+def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)),
+ (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>,
+ Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>;
//
// Add
//
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
new file mode 100644
index 000000000000..9dda6075a23c
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll
@@ -0,0 +1,75 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-NOOPT
+; RUN: llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-OPT
+; RUN: llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT
+;
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | %ptxas-verify %}
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | %ptxas-verify %}
+
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+ %sqrt = tail call float @llvm.nvvm.sqrt.approx.f(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+; CHECK-LABEL: .func{{.*}}test2
+define float @test2(float %in) local_unnamed_addr {
+; CHECK-APPROX-OPT: rsqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT: sqrt.approx.ftz.f32
+; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32
+ %sqrt = tail call float @llvm.nvvm.sqrt.approx.ftz.f(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+ %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define float @test4(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+ %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test5
+define float @test5(float %in) local_unnamed_addr {
+; CHECK-SQRT-OPT: rsqrt.approx.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32
+ %sqrt = tail call float @llvm.sqrt.f32(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+; CHECK-LABEL: .func{{.*}}test6
+define float @test6(float %in) local_unnamed_addr #0 {
+; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32
+; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32
+; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32
+ %sqrt = tail call float @llvm.sqrt.f32(float %in)
+ %rsqrt = fdiv float 1.0, %sqrt
+ ret float %rsqrt
+}
+
+
+declare float @llvm.nvvm.sqrt.f(float)
+declare float @llvm.nvvm.sqrt.approx.f(float)
+declare float @llvm.nvvm.sqrt.approx.ftz.f(float)
+declare float @llvm.sqrt.f32(float)
+
+attributes #0 = { "denormal-fp-math-f32" = "preserve-sign" }
diff --git a/llvm/test/CodeGen/NVPTX/rsqrt.ll b/llvm/test/CodeGen/NVPTX/rsqrt.ll
new file mode 100644
index 000000000000..c7367245c532
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/rsqrt.ll
@@ -0,0 +1,35 @@
+; RUN: llc < %s -march=nvptx64 | FileCheck %s
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %}
+
+; CHECK-LABEL: .func{{.*}}test1
+define float @test1(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f32
+ %call = call float @llvm.nvvm.rsqrt.approx.f(float %in)
+ ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test2
+define double @test2(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.f64
+ %call = call double @llvm.nvvm.rsqrt.approx.d(double %in)
+ ret double %call
+}
+
+; CHECK-LABEL: .func{{.*}}test3
+define float @test3(float %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f32
+ %call = tail call float @llvm.nvvm.rsqrt.approx.ftz.f(float %in)
+ ret float %call
+}
+
+; CHECK-LABEL: .func{{.*}}test4
+define double @test4(double %in) local_unnamed_addr {
+; CHECK: rsqrt.approx.ftz.f64
+ %call = tail call double @llvm.nvvm.rsqrt.approx.ftz.d(double %in)
+ ret double %call
+}
+
+declare float @llvm.nvvm.rsqrt.approx.ftz.f(float)
+declare double @llvm.nvvm.rsqrt.approx.ftz.d(double)
+declare float @llvm.nvvm.rsqrt.approx.f(float)
+declare double @llvm.nvvm.rsqrt.approx.d(double)