summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPaschalis Mpeis <Paschalis.Mpeis@arm.com>2024-01-16 13:52:11 +0000
committerPaschalis Mpeis <Paschalis.Mpeis@arm.com>2024-01-23 09:48:49 +0000
commit2a85ed13c10f877d2f72daa5e7c22fbccdd5fa54 (patch)
tree3d639ffacc4c934b42d6baaa46017f3e8befed5a
parent3c246efd04210af56ab6ce960b98283ec5bc7c30 (diff)
The replace-with-veclib pass now replaces an instruction with a veclib call only when the cost of that call is not found to be higher than the cost of the original instruction.
-rw-r--r--llvm/lib/CodeGen/ReplaceWithVeclib.cpp81
-rw-r--r--llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll2
-rw-r--r--llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll2
3 files changed, 71 insertions, 14 deletions
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
index 432c63fb65f4..97c5b9519ea9 100644
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -9,6 +9,8 @@
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
// from a vector library (e.g libmvec, SVML) using TargetLibraryInfo interface.
+// This happens only when the cost of calling the vector library is not found to
+// be more than the cost of the original instruction.
//
//===----------------------------------------------------------------------===//
@@ -20,12 +22,16 @@
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/VFABIDemangler.h"
+#include "llvm/Support/InstructionCost.h"
#include "llvm/Support/TypeSize.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -96,15 +102,55 @@ static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
Replacement->copyFastMathFlags(&I);
}
+/// Returns whether the vector library call \p TLIFunc costs more than the
+/// original instruction \p I. A non-null \p CI selects the intrinsic-call cost path; otherwise \p I must be an frem.
+static bool isVeclibCallSlower(const TargetLibraryInfo &TLI, // NOTE(review): TLI is not referenced in this body
+ const TargetTransformInfo &TTI, Instruction &I,
+ VectorType *VectorTy, CallInst *CI,
+ Function *TLIFunc) {
+ SmallVector<Type *, 4> OpTypes; // operand types, shared by both cost queries below
+ for (auto &Op : CI ? CI->args() : I.operands()) // call arguments when CI is set, otherwise I's operands
+ OpTypes.push_back(Op->getType());
+
+ TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; // the same cost kind is used for every query in this function
+ InstructionCost DefaultCost; // cost of keeping the original instruction
+ if (CI) { // call path: cost the call via its intrinsic ID
+ FastMathFlags FMF;
+ if (auto *FPMO = dyn_cast<FPMathOperator>(CI)) // forward the call's fast-math flags into the cost query, if it has any
+ FMF = FPMO->getFastMathFlags();
+
+ SmallVector<const Value *> Args(CI->args());
+ IntrinsicCostAttributes CostAttrs(CI->getIntrinsicID(), VectorTy, Args,
+ OpTypes, FMF,
+ dyn_cast<IntrinsicInst>(CI));
+ DefaultCost = TTI.getIntrinsicInstrCost(CostAttrs, CostKind);
+ } else { // non-call path: only frem is expected here (asserted below)
+ assert((I.getOpcode() == Instruction::FRem) && "Only FRem is supported");
+ auto Op2Info = TTI.getOperandInfo(I.getOperand(1)); // operand info derived from the operand at index 1
+ SmallVector<const Value *, 4> OpValues(I.operand_values());
+ DefaultCost = TTI.getArithmeticInstrCost(
+ I.getOpcode(), VectorTy, CostKind,
+ {TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None}, // generic operand info passed ahead of Op2Info (presumably for operand 0 - confirm against the TTI signature)
+ Op2Info, OpValues, &I);
+ }
+
+ InstructionCost VecLibCost = // cost of calling TLIFunc, using VectorTy and the same operand types
+ TTI.getCallInstrCost(TLIFunc, VectorTy, OpTypes, CostKind);
+ return VecLibCost > DefaultCost; // strict comparison: an equal cost is not reported as slower
+}
+
/// Returns true when successfully replaced \p I with a suitable function taking
-/// vector arguments, based on available mappings in the \p TLI. Currently only
-/// works when \p I is a call to vectorized intrinsic or the frem instruction.
+/// vector arguments, based on available mappings in the \p TLI and costs.
+/// Currently only works when \p I is a call to vectorized intrinsic or the frem
+/// instruction.
static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
+ const TargetTransformInfo &TTI,
Instruction &I) {
// At the moment VFABI assumes the return type is always widened unless it is
// a void type.
- auto *VTy = dyn_cast<VectorType>(I.getType());
- ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
+ auto *VectorTy = dyn_cast<VectorType>(I.getType());
+ ElementCount EC(VectorTy ? VectorTy->getElementCount()
+ : ElementCount::getFixed(0));
// Compute the argument types of the corresponding scalar call and the scalar
// function name. For calls, it additionally finds the function to replace
@@ -125,9 +171,10 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
ScalarArgTypes.push_back(VectorArgTy->getElementType());
// When return type is void, set EC to the first vector argument, and
// disallow vector arguments with different ECs.
- if (EC.isZero())
+ if (EC.isZero()) {
EC = VectorArgTy->getElementCount();
- else if (EC != VectorArgTy->getElementCount())
+ VectorTy = VectorArgTy;
+ } else if (EC != VectorArgTy->getElementCount())
return false;
} else
// Exit when it is supposed to be a vector argument but it isn't.
@@ -139,8 +186,8 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
: Intrinsic::getName(IID).str();
} else {
- assert(VTy && "Return type must be a vector");
- auto *ScalarTy = VTy->getScalarType();
+ assert(VectorTy && "Return type must be a vector");
+ auto *ScalarTy = VectorTy->getScalarType();
LibFunc Func;
if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
return false;
@@ -200,6 +247,9 @@ static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
VD->getVectorFnName(), FuncToReplace);
+ if (isVeclibCallSlower(TLI, TTI, I, VectorTy, CI, TLIFunc))
+ return false;
+
replaceWithTLIFunction(I, *OptInfo, TLIFunc);
LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
<< "` with call to `" << TLIFunc->getName() << "`.\n");
@@ -220,13 +270,14 @@ static bool isSupportedInstruction(Instruction *I) {
return false;
}
-static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
+static bool runImpl(const TargetLibraryInfo &TLI,
+ const TargetTransformInfo &TTI, Function &F) {
bool Changed = false;
SmallVector<Instruction *> ReplacedCalls;
for (auto &I : instructions(F)) {
if (!isSupportedInstruction(&I))
continue;
- if (replaceWithCallToVeclib(TLI, I)) {
+ if (replaceWithCallToVeclib(TLI, TTI, I)) {
ReplacedCalls.push_back(&I);
Changed = true;
}
@@ -244,7 +295,8 @@ static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
PreservedAnalyses ReplaceWithVeclib::run(Function &F,
FunctionAnalysisManager &AM) {
const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
- auto Changed = runImpl(TLI, F);
+ const TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+ auto Changed = runImpl(TLI, TTI, F);
if (Changed) {
LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
<< NumCallsReplaced << "\n");
@@ -252,6 +304,7 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
PreservedAnalyses PA;
PA.preserveSet<CFGAnalyses>();
PA.preserve<TargetLibraryAnalysis>();
+ PA.preserve<TargetIRAnalysis>();
PA.preserve<ScalarEvolutionAnalysis>();
PA.preserve<LoopAccessAnalysis>();
PA.preserve<DemandedBitsAnalysis>();
@@ -269,13 +322,17 @@ PreservedAnalyses ReplaceWithVeclib::run(Function &F,
bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
const TargetLibraryInfo &TLI =
getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
- return runImpl(TLI, F);
+ const TargetTransformInfo &TTI =
+ getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+ return runImpl(TLI, TTI, F);
}
void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
AU.setPreservesCFG();
AU.addRequired<TargetLibraryInfoWrapperPass>();
+ AU.addRequired<TargetTransformInfoWrapperPass>();
AU.addPreserved<TargetLibraryInfoWrapperPass>();
+ AU.addPreserved<TargetTransformInfoWrapperPass>();
AU.addPreserved<ScalarEvolutionWrapperPass>();
AU.addPreserved<AAResultsWrapperPass>();
AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
index 758df0493cc5..d3e1ae338f2c 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-armpl.ll
@@ -428,7 +428,7 @@ define <vscale x 4 x float> @llvm_sin_vscale_f32(<vscale x 4 x float> %in) #0 {
define <2 x double> @frem_f64(<2 x double> %in) {
; CHECK-LABEL: define <2 x double> @frem_f64
; CHECK-SAME: (<2 x double> [[IN:%.*]]) {
-; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @armpl_vfmodq_f64(<2 x double> [[IN]], <2 x double> [[IN]])
+; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN]], [[IN]]
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1= frem <2 x double> %in, %in
diff --git a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
index f408df570fdc..69b16b02adaa 100644
--- a/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
+++ b/llvm/test/CodeGen/AArch64/replace-with-veclib-sleef.ll
@@ -386,7 +386,7 @@ define <4 x float> @llvm_trunc_f32(<4 x float> %in) {
define <2 x double> @frem_f64(<2 x double> %in) {
; CHECK-LABEL: @frem_f64(
-; CHECK-NEXT: [[TMP1:%.*]] = call <2 x double> @_ZGVnN2vv_fmod(<2 x double> [[IN:%.*]], <2 x double> [[IN]])
+; CHECK-NEXT: [[TMP1:%.*]] = frem <2 x double> [[IN:%.*]], [[IN]]
; CHECK-NEXT: ret <2 x double> [[TMP1]]
;
%1= frem <2 x double> %in, %in