diff options
author | Noah Goldstein <goldstein.w.n@gmail.com> | 2024-03-18 13:00:14 -0500 |
---|---|---|
committer | Noah Goldstein <goldstein.w.n@gmail.com> | 2024-05-03 14:10:24 -0500 |
commit | d8428dfeb8d9a0bbb5345f96f29a4a66eb950769 (patch) | |
tree | 6e1819e7107eb519b3825fc622a3394ae8898d43 | |
parent | 285dbed147e243f416b003e150d67ffb0922ff16 (diff) |
[PatternMatching] Add generic API for matching constants using custom conditions
The new API is:
`m_CheckedInt(Lambda)`/`m_CheckedFp(Lambda)`
- Matches non-undef constants s.t `Lambda(ele)` is true for all
elements.
`m_CheckedIntAllowUndef(Lambda)`/`m_CheckedFpAllowUndef(Lambda)`
- Matches constants/undef s.t `Lambda(ele)` is true for all
elements.
The goal with these is to be able to replace the common usage of:
```
match(X, m_APInt(C)) && CustomCheck(C)
```
with
```
match(X, m_CheckedInt(C, CustomChecks);
```
The rationale if we often ignore non-splat vectors because there are
no good APIs to handle them with and its not worth increasing code
complexity for such cases.
The hope is the API creates a common method handling
scalars/splat-vecs/non-splat-vecs to essentially make this a
non-issue.
-rw-r--r-- | llvm/include/llvm/IR/PatternMatch.h | 33 | ||||
-rw-r--r-- | llvm/unittests/IR/PatternMatch.cpp | 177 |
2 files changed, 210 insertions, 0 deletions
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h index 5da4956c54e8..5d8f5c134bb5 100644 --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -460,6 +460,39 @@ template <typename Predicate> struct apf_pred_ty : public Predicate { // /////////////////////////////////////////////////////////////////////////////// +template <typename APTy> struct custom_checkfn { + function_ref<bool(const APTy &)> CheckFn; + bool isValue(const APTy &C) { return CheckFn(C); } +}; + +/// Match an integer or vector where CheckFn(ele) for each element is true. +/// For vectors, poison elements are assumed to match. +inline cst_pred_ty<custom_checkfn<APInt>> +m_CheckedInt(function_ref<bool(const APInt &)> CheckFn) { + return cst_pred_ty<custom_checkfn<APInt>>{CheckFn}; +} + +inline api_pred_ty<custom_checkfn<APInt>> +m_CheckedInt(const APInt *&V, function_ref<bool(const APInt &)> CheckFn) { + api_pred_ty<custom_checkfn<APInt>> P(V); + P.CheckFn = CheckFn; + return P; +} + +/// Match a float or vector where CheckFn(ele) for each element is true. +/// For vectors, poison elements are assumed to match. +inline cstfp_pred_ty<custom_checkfn<APFloat>> +m_CheckedFp(function_ref<bool(const APFloat &)> CheckFn) { + return cstfp_pred_ty<custom_checkfn<APFloat>>{CheckFn}; +} + +inline apf_pred_ty<custom_checkfn<APFloat>> +m_CheckedFp(const APFloat *&V, function_ref<bool(const APFloat &)> CheckFn) { + apf_pred_ty<custom_checkfn<APFloat>> P(V); + P.CheckFn = CheckFn; + return P; +} + struct is_any_apint { bool isValue(const APInt &C) { return true; } }; diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp index a25885faa3a4..d5a4a6a05687 100644 --- a/llvm/unittests/IR/PatternMatch.cpp +++ b/llvm/unittests/IR/PatternMatch.cpp @@ -611,6 +611,134 @@ TEST_F(PatternMatchTest, BitCast) { EXPECT_FALSE(m_ElementWiseBitCast(m_Value()).match(NXV2I64ToNXV4I32)); } +TEST_F(PatternMatchTest, CheckedInt) { + Type *I8Ty = IRB.getInt8Ty(); + const APInt *Res = nullptr; + + auto CheckUgt1 = [](const APInt &C) { return C.ugt(1); }; + auto CheckTrue = [](const APInt &) { return true; }; + auto CheckFalse = [](const APInt &) { return false; }; + auto CheckNonZero = [](const APInt &C) { return !C.isZero(); }; + auto CheckPow2 = [](const APInt &C) { return C.isPowerOf2(); }; + + auto DoScalarCheck = [&](int8_t Val) { + APInt APVal(8, Val); + Constant *C = ConstantInt::get(I8Ty, Val); + + Res = nullptr; + EXPECT_TRUE(m_CheckedInt(CheckTrue).match(C)); + EXPECT_TRUE(m_CheckedInt(Res, CheckTrue).match(C)); + EXPECT_EQ(*Res, APVal); + + Res = nullptr; + EXPECT_FALSE(m_CheckedInt(CheckFalse).match(C)); + EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(C)); + + Res = nullptr; + EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(CheckUgt1).match(C)); + EXPECT_EQ(CheckUgt1(APVal), m_CheckedInt(Res, CheckUgt1).match(C)); + if (CheckUgt1(APVal)) { + EXPECT_NE(Res, nullptr); + EXPECT_EQ(*Res, APVal); + } + + Res = nullptr; + EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(CheckNonZero).match(C)); + EXPECT_EQ(CheckNonZero(APVal), m_CheckedInt(Res, CheckNonZero).match(C)); + if (CheckNonZero(APVal)) { + EXPECT_NE(Res, nullptr); + EXPECT_EQ(*Res, APVal); + } + + Res = nullptr; + EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(CheckPow2).match(C)); + EXPECT_EQ(CheckPow2(APVal), m_CheckedInt(Res, CheckPow2).match(C)); + if (CheckPow2(APVal)) { + EXPECT_NE(Res, nullptr); + EXPECT_EQ(*Res, APVal); + } + + }; + + DoScalarCheck(0); + DoScalarCheck(1); + DoScalarCheck(2); + DoScalarCheck(3); + + EXPECT_FALSE(m_CheckedInt(CheckTrue).match(UndefValue::get(I8Ty))); + EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(UndefValue::get(I8Ty))); + EXPECT_EQ(Res, nullptr); + + EXPECT_FALSE(m_CheckedInt(CheckFalse).match(UndefValue::get(I8Ty))); + EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(UndefValue::get(I8Ty))); + EXPECT_EQ(Res, nullptr); + + EXPECT_FALSE(m_CheckedInt(CheckTrue).match(PoisonValue::get(I8Ty))); + EXPECT_FALSE(m_CheckedInt(Res, CheckTrue).match(PoisonValue::get(I8Ty))); + EXPECT_EQ(Res, nullptr); + + EXPECT_FALSE(m_CheckedInt(CheckFalse).match(PoisonValue::get(I8Ty))); + EXPECT_FALSE(m_CheckedInt(Res, CheckFalse).match(PoisonValue::get(I8Ty))); + EXPECT_EQ(Res, nullptr); + + auto DoVecCheckImpl = [&](ArrayRef<std::optional<int8_t>> Vals, + function_ref<bool(const APInt &)> CheckFn, + bool UndefAsPoison) { + SmallVector<Constant *> VecElems; + std::optional<bool> Okay; + bool AllSame = true; + bool HasUndef = false; + std::optional<APInt> First; + for (const std::optional<int8_t> &Val : Vals) { + if (!Val.has_value()) { + VecElems.push_back(UndefAsPoison ? PoisonValue::get(I8Ty) + : UndefValue::get(I8Ty)); + HasUndef = true; + } else { + if (!Okay.has_value()) + Okay = true; + APInt APVal(8, *Val); + if (!First.has_value()) + First = APVal; + else + AllSame &= First->eq(APVal); + Okay = *Okay && CheckFn(APVal); + VecElems.push_back(ConstantInt::get(I8Ty, *Val)); + } + } + + Constant *C = ConstantVector::get(VecElems); + EXPECT_EQ(!(HasUndef && !UndefAsPoison) && Okay.value_or(false), + m_CheckedInt(CheckFn).match(C)); + + Res = nullptr; + bool Expec = + !(HasUndef && !UndefAsPoison) && AllSame && Okay.value_or(false); + EXPECT_EQ(Expec, m_CheckedInt(Res, CheckFn).match(C)); + if (Expec) { + EXPECT_NE(Res, nullptr); + EXPECT_EQ(*Res, *First); + } + }; + auto DoVecCheck = [&](ArrayRef<std::optional<int8_t>> Vals) { + DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/false); + DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/false); + DoVecCheckImpl(Vals, CheckTrue, /*UndefAsPoison=*/true); + DoVecCheckImpl(Vals, CheckFalse, /*UndefAsPoison=*/true); + DoVecCheckImpl(Vals, CheckUgt1, /*UndefAsPoison=*/false); + DoVecCheckImpl(Vals, CheckNonZero, /*UndefAsPoison=*/false); + DoVecCheckImpl(Vals, CheckPow2, /*UndefAsPoison=*/false); + }; + + DoVecCheck({0, 1}); + DoVecCheck({1, 1}); + DoVecCheck({1, 2}); + DoVecCheck({1, std::nullopt}); + DoVecCheck({1, std::nullopt, 1}); + DoVecCheck({1, std::nullopt, 2}); + DoVecCheck({std::nullopt, std::nullopt, std::nullopt}); +} + TEST_F(PatternMatchTest, Power2) { Value *C128 = IRB.getInt32(128); Value *CNeg128 = ConstantExpr::getNeg(cast<Constant>(C128)); @@ -1397,21 +1525,58 @@ TEST_F(PatternMatchTest, VectorUndefFloat) { EXPECT_FALSE(match(VectorInfPoison, m_Finite())); EXPECT_FALSE(match(VectorNaNPoison, m_Finite())); + auto CheckTrue = [](const APFloat &) { return true; }; + EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckTrue))); + EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckTrue))); + EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckTrue))); + EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckTrue))); + EXPECT_TRUE(match(ScalarNaN, m_CheckedFp(CheckTrue))); + EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckTrue))); + EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckTrue))); + EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckTrue))); + EXPECT_TRUE(match(VectorNaNPoison, m_CheckedFp(CheckTrue))); + + auto CheckFalse = [](const APFloat &) { return false; }; + EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckFalse))); + EXPECT_FALSE(match(VectorZeroPoison, m_CheckedFp(CheckFalse))); + EXPECT_FALSE(match(ScalarPosInf, m_CheckedFp(CheckFalse))); + EXPECT_FALSE(match(ScalarNegInf, m_CheckedFp(CheckFalse))); + EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckFalse))); + EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckFalse))); + EXPECT_FALSE(match(VectorInfPoison, m_CheckedFp(CheckFalse))); + EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckFalse))); + EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckFalse))); + + auto CheckNonNaN = [](const APFloat &C) { return !C.isNaN(); }; + EXPECT_FALSE(match(VectorZeroUndef, m_CheckedFp(CheckNonNaN))); + EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(CheckNonNaN))); + EXPECT_TRUE(match(ScalarPosInf, m_CheckedFp(CheckNonNaN))); + EXPECT_TRUE(match(ScalarNegInf, m_CheckedFp(CheckNonNaN))); + EXPECT_FALSE(match(ScalarNaN, m_CheckedFp(CheckNonNaN))); + EXPECT_FALSE(match(VectorInfUndef, m_CheckedFp(CheckNonNaN))); + EXPECT_TRUE(match(VectorInfPoison, m_CheckedFp(CheckNonNaN))); + EXPECT_FALSE(match(VectorNaNUndef, m_CheckedFp(CheckNonNaN))); + EXPECT_FALSE(match(VectorNaNPoison, m_CheckedFp(CheckNonNaN))); + const APFloat *C; // Regardless of whether poison is allowed, // a fully undef/poison constant does not match. EXPECT_FALSE(match(ScalarUndef, m_APFloat(C))); EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidPoison(C))); EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowPoison(C))); + EXPECT_FALSE(match(ScalarUndef, m_CheckedFp(C, CheckTrue))); EXPECT_FALSE(match(VectorUndef, m_APFloat(C))); EXPECT_FALSE(match(VectorUndef, m_APFloatForbidPoison(C))); EXPECT_FALSE(match(VectorUndef, m_APFloatAllowPoison(C))); + EXPECT_FALSE(match(VectorUndef, m_CheckedFp(C, CheckTrue))); EXPECT_FALSE(match(ScalarPoison, m_APFloat(C))); EXPECT_FALSE(match(ScalarPoison, m_APFloatForbidPoison(C))); EXPECT_FALSE(match(ScalarPoison, m_APFloatAllowPoison(C))); + EXPECT_FALSE(match(ScalarPoison, m_CheckedFp(C, CheckTrue))); EXPECT_FALSE(match(VectorPoison, m_APFloat(C))); EXPECT_FALSE(match(VectorPoison, m_APFloatForbidPoison(C))); EXPECT_FALSE(match(VectorPoison, m_APFloatAllowPoison(C))); + EXPECT_FALSE(match(VectorPoison, m_CheckedFp(C, CheckTrue))); // We can always match simple constants and simple splats. C = nullptr; @@ -1432,6 +1597,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) { C = nullptr; EXPECT_TRUE(match(VectorZero, m_APFloatAllowPoison(C))); EXPECT_TRUE(C->isZero()); + C = nullptr; + EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckTrue))); + EXPECT_TRUE(C->isZero()); + C = nullptr; + EXPECT_TRUE(match(VectorZero, m_CheckedFp(C, CheckNonNaN))); + EXPECT_TRUE(C->isZero()); // Splats with undef are never allowed. // Whether splats with poison can be matched depends on the matcher. @@ -1456,6 +1627,12 @@ TEST_F(PatternMatchTest, VectorUndefFloat) { C = nullptr; EXPECT_TRUE(match(VectorZeroPoison, m_Finite(C))); EXPECT_TRUE(C->isZero()); + C = nullptr; + EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckTrue))); + EXPECT_TRUE(C->isZero()); + C = nullptr; + EXPECT_TRUE(match(VectorZeroPoison, m_CheckedFp(C, CheckNonNaN))); + EXPECT_TRUE(C->isZero()); } TEST_F(PatternMatchTest, FloatingPointFNeg) { |