diff options
author | Alex MacLean <amaclean@nvidia.com> | 2024-04-23 08:56:39 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-23 08:56:39 -0700 |
commit | df608051234c256f1dc2c89f30afd034706c2c2e (patch) | |
tree | 11e5882cfcf1a7881dda4b31664c24403157df6d | |
parent | f426be195a08874686d01783bbc490295bf4afb2 (diff) |
[NVPTX] Improve support for rsqrt.approx (#89417)
Complete support for rsqrt.approx by adding rsqrt.approx.ftz.f64 ([PTX ISA
9.7.3.17. Floating Point Instructions:
rsqrt.approx.ftz.f64](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64)).
Additionally, add support for folding `1.0 / sqrt(x)` into `rsqrt(x)` for
f32, with an optional flag (`-nvptx-rsqrt-approx-opt`, enabled by default) to
disable the fold.
-rw-r--r-- | llvm/include/llvm/IR/IntrinsicsNVVM.td | 2 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp | 6 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h | 1 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXInstrInfo.td | 1 | ||||
-rw-r--r-- | llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 25 | ||||
-rw-r--r-- | llvm/test/CodeGen/NVPTX/rsqrt-opt.ll | 75 | ||||
-rw-r--r-- | llvm/test/CodeGen/NVPTX/rsqrt.ll | 35 |
7 files changed, 145 insertions, 0 deletions
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 726cea004606..0a9139e0062b 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1003,6 +1003,8 @@ let TargetPrefix = "nvvm" in { def int_nvvm_rsqrt_approx_ftz_f : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_f">, DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>; + def int_nvvm_rsqrt_approx_ftz_d : ClangBuiltin<"__nvvm_rsqrt_approx_ftz_d">, + DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty], [IntrNoMem]>; def int_nvvm_rsqrt_approx_f : ClangBuiltin<"__nvvm_rsqrt_approx_f">, DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>; def int_nvvm_rsqrt_approx_d : ClangBuiltin<"__nvvm_rsqrt_approx_d">, diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index a362709c98ef..595395bb1b4b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -30,6 +30,10 @@ using namespace llvm; #define DEBUG_TYPE "nvptx-isel" #define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection" +static cl::opt<bool> + EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden, + cl::desc("Enable reciprocal sqrt optimization")); + /// createNVPTXISelDag - This pass converts a legalized DAG into a /// NVPTX-specific DAG, ready for instruction scheduling. FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM, @@ -74,6 +78,8 @@ bool NVPTXDAGToDAGISel::allowUnsafeFPMath() const { return TL->allowUnsafeFPMath(*MF); } +bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; } + /// Select - Select instructions not customized! Used for /// expanded, promoted and normal instructions. 
void NVPTXDAGToDAGISel::Select(SDNode *N) { diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h index 10822f87cef3..7a7774744bc7 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h @@ -36,6 +36,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel { bool useF32FTZ() const; bool allowFMA() const; bool allowUnsafeFPMath() const; + bool doRsqrtOpt() const; public: static char ID; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 931292c7fd60..897ee89323f0 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -142,6 +142,7 @@ def hasLDU : Predicate<"Subtarget->hasLDU()">; def doF32FTZ : Predicate<"useF32FTZ()">; def doNoF32FTZ : Predicate<"!useF32FTZ()">; +def doRsqrtOpt : Predicate<"doRsqrtOpt()">; def doMulWide : Predicate<"doMulWide">; diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index ec9170b4e41e..5f6e28283c5d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1171,11 +1171,36 @@ def : Pat<(int_nvvm_sqrt_f Float32Regs:$a), def INT_NVVM_RSQRT_APPROX_FTZ_F : F_MATH_1<"rsqrt.approx.ftz.f32 \t$dst, $src0;", Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_ftz_f>; +def INT_NVVM_RSQRT_APPROX_FTZ_D + : F_MATH_1<"rsqrt.approx.ftz.f64 \t$dst, $src0;", Float64Regs, Float64Regs, + int_nvvm_rsqrt_approx_ftz_d>; + def INT_NVVM_RSQRT_APPROX_F : F_MATH_1<"rsqrt.approx.f32 \t$dst, $src0;", Float32Regs, Float32Regs, int_nvvm_rsqrt_approx_f>; def INT_NVVM_RSQRT_APPROX_D : F_MATH_1<"rsqrt.approx.f64 \t$dst, $src0;", Float64Regs, Float64Regs, int_nvvm_rsqrt_approx_d>; +// 1.0f / sqrt_approx -> rsqrt_approx +def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_approx_f Float32Regs:$a)), + (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>, + Requires<[doRsqrtOpt]>; +def: Pat<(fdiv 
FloatConst1, (int_nvvm_sqrt_approx_ftz_f Float32Regs:$a)), + (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>, + Requires<[doRsqrtOpt]>; +// same for int_nvvm_sqrt_f when non-precision sqrt is requested +def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)), + (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>, + Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>; +def: Pat<(fdiv FloatConst1, (int_nvvm_sqrt_f Float32Regs:$a)), + (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>, + Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>; + +def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)), + (INT_NVVM_RSQRT_APPROX_F Float32Regs:$a)>, + Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doNoF32FTZ]>; +def: Pat<(fdiv FloatConst1, (fsqrt Float32Regs:$a)), + (INT_NVVM_RSQRT_APPROX_FTZ_F Float32Regs:$a)>, + Requires<[doRsqrtOpt, do_SQRTF32_APPROX, doF32FTZ]>; // // Add // diff --git a/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll new file mode 100644 index 000000000000..9dda6075a23c --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/rsqrt-opt.ll @@ -0,0 +1,75 @@ +; RUN: llc < %s -march=nvptx64 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-NOOPT +; RUN: llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-OPT,CHECK-SQRT-OPT +; RUN: llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | FileCheck %s --check-prefixes CHECK,CHECK-APPROX-NOOPT,CHECK-SQRT-NOOPT +; +; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} +; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-prec-sqrtf32=0 | %ptxas-verify %} +; RUN: %if ptxas %{ llc < %s -march=nvptx64 -nvptx-rsqrt-approx-opt=0 | %ptxas-verify %} + + +; CHECK-LABEL: .func{{.*}}test1 +define float @test1(float %in) local_unnamed_addr { +; CHECK-APPROX-OPT: rsqrt.approx.f32 +; CHECK-APPROX-NOOPT: sqrt.approx.f32 +; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32 + %sqrt = tail call float @llvm.nvvm.sqrt.approx.f(float %in) + %rsqrt = fdiv float 1.0, %sqrt + ret float 
%rsqrt +} +; CHECK-LABEL: .func{{.*}}test2 +define float @test2(float %in) local_unnamed_addr { +; CHECK-APPROX-OPT: rsqrt.approx.ftz.f32 +; CHECK-APPROX-NOOPT: sqrt.approx.ftz.f32 +; CHECK-APPROX-NOOPT-NEXT: rcp.rn.f32 + %sqrt = tail call float @llvm.nvvm.sqrt.approx.ftz.f(float %in) + %rsqrt = fdiv float 1.0, %sqrt + ret float %rsqrt +} + +; CHECK-LABEL: .func{{.*}}test3 +define float @test3(float %in) local_unnamed_addr { +; CHECK-SQRT-OPT: rsqrt.approx.f32 +; CHECK-SQRT-NOOPT: sqrt.rn.f32 +; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32 + %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in) + %rsqrt = fdiv float 1.0, %sqrt + ret float %rsqrt +} + +; CHECK-LABEL: .func{{.*}}test4 +define float @test4(float %in) local_unnamed_addr #0 { +; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32 +; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32 +; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32 + %sqrt = tail call float @llvm.nvvm.sqrt.f(float %in) + %rsqrt = fdiv float 1.0, %sqrt + ret float %rsqrt +} + +; CHECK-LABEL: .func{{.*}}test5 +define float @test5(float %in) local_unnamed_addr { +; CHECK-SQRT-OPT: rsqrt.approx.f32 +; CHECK-SQRT-NOOPT: sqrt.rn.f32 +; CHECK-SQRT-NOOPT-NEXT: rcp.rn.f32 + %sqrt = tail call float @llvm.sqrt.f32(float %in) + %rsqrt = fdiv float 1.0, %sqrt + ret float %rsqrt +} + +; CHECK-LABEL: .func{{.*}}test6 +define float @test6(float %in) local_unnamed_addr #0 { +; CHECK-SQRT-OPT: rsqrt.approx.ftz.f32 +; CHECK-SQRT-NOOPT: sqrt.rn.ftz.f32 +; CHECK-SQRT-NOOPT-NEXT: rcp.rn.ftz.f32 + %sqrt = tail call float @llvm.sqrt.f32(float %in) + %rsqrt = fdiv float 1.0, %sqrt + ret float %rsqrt +} + + +declare float @llvm.nvvm.sqrt.f(float) +declare float @llvm.nvvm.sqrt.approx.f(float) +declare float @llvm.nvvm.sqrt.approx.ftz.f(float) +declare float @llvm.sqrt.f32(float) + +attributes #0 = { "denormal-fp-math-f32" = "preserve-sign" } diff --git a/llvm/test/CodeGen/NVPTX/rsqrt.ll b/llvm/test/CodeGen/NVPTX/rsqrt.ll new file mode 100644 index 000000000000..c7367245c532 --- /dev/null +++ 
b/llvm/test/CodeGen/NVPTX/rsqrt.ll @@ -0,0 +1,35 @@ +; RUN: llc < %s -march=nvptx64 | FileCheck %s +; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} + +; CHECK-LABEL: .func{{.*}}test1 +define float @test1(float %in) local_unnamed_addr { +; CHECK: rsqrt.approx.f32 + %call = call float @llvm.nvvm.rsqrt.approx.f(float %in) + ret float %call +} + +; CHECK-LABEL: .func{{.*}}test2 +define double @test2(double %in) local_unnamed_addr { +; CHECK: rsqrt.approx.f64 + %call = call double @llvm.nvvm.rsqrt.approx.d(double %in) + ret double %call +} + +; CHECK-LABEL: .func{{.*}}test3 +define float @test3(float %in) local_unnamed_addr { +; CHECK: rsqrt.approx.ftz.f32 + %call = tail call float @llvm.nvvm.rsqrt.approx.ftz.f(float %in) + ret float %call +} + +; CHECK-LABEL: .func{{.*}}test4 +define double @test4(double %in) local_unnamed_addr { +; CHECK: rsqrt.approx.ftz.f64 + %call = tail call double @llvm.nvvm.rsqrt.approx.ftz.d(double %in) + ret double %call +} + +declare float @llvm.nvvm.rsqrt.approx.ftz.f(float) +declare double @llvm.nvvm.rsqrt.approx.ftz.d(double) +declare float @llvm.nvvm.rsqrt.approx.f(float) +declare double @llvm.nvvm.rsqrt.approx.d(double) |