diff options
author | Min-Yih Hsu <min.hsu@sifive.com> | 2024-02-23 11:03:36 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-23 11:03:36 -0800 |
commit | 5874874c24720dc24fde12327f81369ef4af4e0b (patch) | |
tree | e2be685825be1a29191a32f67e797432515d3274 | |
parent | f8ce460e48ccc774354df75520d00a67ddbf84c0 (diff) |
[SelectionDAG] Introducing the SelectionDAG pattern matching framework (#78654)
Akin to `llvm::PatternMatch` and `llvm::MIPatternMatch`, the
`llvm::SDPatternMatch` introduced in this patch provides a DSL-alike
framework to match SDValue / SDNode with a more succinct syntax.
-rw-r--r-- | llvm/include/llvm/CodeGen/SDPatternMatch.h | 694 | ||||
-rw-r--r-- | llvm/unittests/CodeGen/CMakeLists.txt | 1 | ||||
-rw-r--r-- | llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp | 292 |
3 files changed, 987 insertions, 0 deletions
diff --git a/llvm/include/llvm/CodeGen/SDPatternMatch.h b/llvm/include/llvm/CodeGen/SDPatternMatch.h new file mode 100644 index 000000000000..412bf42677cc --- /dev/null +++ b/llvm/include/llvm/CodeGen/SDPatternMatch.h @@ -0,0 +1,694 @@ +//==--------------- llvm/CodeGen/SDPatternMatch.h ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// \file +/// Contains matchers for matching SelectionDAG nodes and values. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CODEGEN_SDPATTERNMATCH_H +#define LLVM_CODEGEN_SDPATTERNMATCH_H + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/CodeGen/SelectionDAG.h" +#include "llvm/CodeGen/SelectionDAGNodes.h" +#include "llvm/CodeGen/TargetLowering.h" + +namespace llvm { +namespace SDPatternMatch { + +/// MatchContext can repurpose existing patterns to behave differently under +/// a certain context. For instance, `m_Opc(ISD::ADD)` matches plain ADD nodes +/// in normal circumstances, but matches VP_ADD nodes under a custom +/// VPMatchContext. This design is meant to facilitate code / pattern reusing. +class BasicMatchContext { + const SelectionDAG *DAG; + const TargetLowering *TLI; + +public: + explicit BasicMatchContext(const SelectionDAG *DAG) + : DAG(DAG), TLI(DAG ? &DAG->getTargetLoweringInfo() : nullptr) {} + + explicit BasicMatchContext(const TargetLowering *TLI) + : DAG(nullptr), TLI(TLI) {} + + // A valid MatchContext has to implement the following functions. + + const SelectionDAG *getDAG() const { return DAG; } + + const TargetLowering *getTLI() const { return TLI; } + + /// Return true if N effectively has opcode Opcode. + bool match(SDValue N, unsigned Opcode) const { + return N->getOpcode() == Opcode; + } +}; + +template <typename Pattern, typename MatchContext> +[[nodiscard]] bool sd_context_match(SDValue N, const MatchContext &Ctx, + Pattern &&P) { + return P.match(Ctx, N); +} + +template <typename Pattern, typename MatchContext> +[[nodiscard]] bool sd_context_match(SDNode *N, const MatchContext &Ctx, + Pattern &&P) { + return sd_context_match(SDValue(N, 0), Ctx, P); +} + +template <typename Pattern> +[[nodiscard]] bool sd_match(SDNode *N, const SelectionDAG *DAG, Pattern &&P) { + return sd_context_match(N, BasicMatchContext(DAG), P); +} + +template <typename Pattern> +[[nodiscard]] bool sd_match(SDValue N, const SelectionDAG *DAG, Pattern &&P) { + return sd_context_match(N, BasicMatchContext(DAG), P); +} + +template <typename Pattern> +[[nodiscard]] bool sd_match(SDNode *N, Pattern &&P) { + return sd_match(N, nullptr, P); +} + +template <typename Pattern> +[[nodiscard]] bool sd_match(SDValue N, Pattern &&P) { + return sd_match(N, nullptr, P); +} + +// === Utilities === +struct Value_match { + SDValue MatchVal; + + Value_match() = default; + + explicit Value_match(SDValue Match) : MatchVal(Match) {} + + template <typename MatchContext> bool match(const MatchContext &, SDValue N) { + if (MatchVal) + return MatchVal == N; + return N.getNode(); + } +}; + +/// Match any valid SDValue. +inline Value_match m_Value() { return Value_match(); } + +inline Value_match m_Specific(SDValue N) { + assert(N); + return Value_match(N); +} + +struct DeferredValue_match { + SDValue &MatchVal; + + explicit DeferredValue_match(SDValue &Match) : MatchVal(Match) {} + + template <typename MatchContext> bool match(const MatchContext &, SDValue N) { + return N == MatchVal; + } +}; + +/// Similar to m_Specific, but the specific value to match is determined by +/// another sub-pattern in the same sd_match() expression. For instance, +/// We cannot match `(add V, V)` with `m_Add(m_Value(X), m_Specific(X))` since +/// `X` is not initialized at the time it got copied into `m_Specific`. Instead, +/// we should use `m_Add(m_Value(X), m_Deferred(X))`. +inline DeferredValue_match m_Deferred(SDValue &V) { + return DeferredValue_match(V); +} + +struct Opcode_match { + unsigned Opcode; + + explicit Opcode_match(unsigned Opc) : Opcode(Opc) {} + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + return Ctx.match(N, Opcode); + } +}; + +inline Opcode_match m_Opc(unsigned Opcode) { return Opcode_match(Opcode); } + +template <unsigned NumUses, typename Pattern> struct NUses_match { + Pattern P; + + explicit NUses_match(const Pattern &P) : P(P) {} + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + // SDNode::hasNUsesOfValue is pretty expensive when the SDNode produces + // multiple results, hence we check the subsequent pattern here before + // checking the number of value users. + return P.match(Ctx, N) && N->hasNUsesOfValue(NumUses, N.getResNo()); + } +}; + +template <typename Pattern> +inline NUses_match<1, Pattern> m_OneUse(const Pattern &P) { + return NUses_match<1, Pattern>(P); +} +template <unsigned N, typename Pattern> +inline NUses_match<N, Pattern> m_NUses(const Pattern &P) { + return NUses_match<N, Pattern>(P); +} + +inline NUses_match<1, Value_match> m_OneUse() { + return NUses_match<1, Value_match>(m_Value()); +} +template <unsigned N> inline NUses_match<N, Value_match> m_NUses() { + return NUses_match<N, Value_match>(m_Value()); +} + +struct Value_bind { + SDValue &BindVal; + + explicit Value_bind(SDValue &N) : BindVal(N) {} + + template <typename MatchContext> bool match(const MatchContext &, SDValue N) { + BindVal = N; + return true; + } +}; + +inline Value_bind m_Value(SDValue &N) { return Value_bind(N); } + +template <typename Pattern, typename PredFuncT> struct TLI_pred_match { + Pattern P; + PredFuncT PredFunc; + + TLI_pred_match(const PredFuncT &Pred, const Pattern &P) + : P(P), PredFunc(Pred) {} + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + assert(Ctx.getTLI() && "TargetLowering is required for this pattern."); + return PredFunc(*Ctx.getTLI(), N) && P.match(Ctx, N); + } +}; + +// Explicit deduction guide. +template <typename PredFuncT, typename Pattern> +TLI_pred_match(const PredFuncT &Pred, const Pattern &P) + -> TLI_pred_match<Pattern, PredFuncT>; + +/// Match legal SDNodes based on the information provided by TargetLowering. +template <typename Pattern> inline auto m_LegalOp(const Pattern &P) { + return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) { + return TLI.isOperationLegal(N->getOpcode(), + N.getValueType()); + }, + P}; +} + +/// Switch to a different MatchContext for subsequent patterns. +template <typename NewMatchContext, typename Pattern> struct SwitchContext { + const NewMatchContext &Ctx; + Pattern P; + + template <typename OrigMatchContext> + bool match(const OrigMatchContext &, SDValue N) { + return P.match(Ctx, N); + } +}; + +template <typename MatchContext, typename Pattern> +inline SwitchContext<MatchContext, Pattern> m_Context(const MatchContext &Ctx, + Pattern &&P) { + return SwitchContext<MatchContext, Pattern>{Ctx, std::move(P)}; +} + +// === Value type === +struct ValueType_bind { + EVT &BindVT; + + explicit ValueType_bind(EVT &Bind) : BindVT(Bind) {} + + template <typename MatchContext> bool match(const MatchContext &, SDValue N) { + BindVT = N.getValueType(); + return true; + } +}; + +/// Retreive the ValueType of the current SDValue. +inline ValueType_bind m_VT(EVT &VT) { return ValueType_bind(VT); } + +template <typename Pattern, typename PredFuncT> struct ValueType_match { + PredFuncT PredFunc; + Pattern P; + + ValueType_match(const PredFuncT &Pred, const Pattern &P) + : PredFunc(Pred), P(P) {} + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + return PredFunc(N.getValueType()) && P.match(Ctx, N); + } +}; + +// Explicit deduction guide. +template <typename PredFuncT, typename Pattern> +ValueType_match(const PredFuncT &Pred, const Pattern &P) + -> ValueType_match<Pattern, PredFuncT>; + +/// Match a specific ValueType. +template <typename Pattern> +inline auto m_SpecificVT(EVT RefVT, const Pattern &P) { + return ValueType_match{[=](EVT VT) { return VT == RefVT; }, P}; +} +inline auto m_SpecificVT(EVT RefVT) { + return ValueType_match{[=](EVT VT) { return VT == RefVT; }, m_Value()}; +} + +inline auto m_Glue() { return m_SpecificVT(MVT::Glue); } +inline auto m_OtherVT() { return m_SpecificVT(MVT::Other); } + +/// Match any integer ValueTypes. +template <typename Pattern> inline auto m_IntegerVT(const Pattern &P) { + return ValueType_match{[](EVT VT) { return VT.isInteger(); }, P}; +} +inline auto m_IntegerVT() { + return ValueType_match{[](EVT VT) { return VT.isInteger(); }, m_Value()}; +} + +/// Match any floating point ValueTypes. +template <typename Pattern> inline auto m_FloatingPointVT(const Pattern &P) { + return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, P}; +} +inline auto m_FloatingPointVT() { + return ValueType_match{[](EVT VT) { return VT.isFloatingPoint(); }, + m_Value()}; +} + +/// Match any vector ValueTypes. +template <typename Pattern> inline auto m_VectorVT(const Pattern &P) { + return ValueType_match{[](EVT VT) { return VT.isVector(); }, P}; +} +inline auto m_VectorVT() { + return ValueType_match{[](EVT VT) { return VT.isVector(); }, m_Value()}; +} + +/// Match fixed-length vector ValueTypes. +template <typename Pattern> inline auto m_FixedVectorVT(const Pattern &P) { + return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, P}; +} +inline auto m_FixedVectorVT() { + return ValueType_match{[](EVT VT) { return VT.isFixedLengthVector(); }, + m_Value()}; +} + +/// Match scalable vector ValueTypes. +template <typename Pattern> inline auto m_ScalableVectorVT(const Pattern &P) { + return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, P}; +} +inline auto m_ScalableVectorVT() { + return ValueType_match{[](EVT VT) { return VT.isScalableVector(); }, + m_Value()}; +} + +/// Match legal ValueTypes based on the information provided by TargetLowering. +template <typename Pattern> inline auto m_LegalType(const Pattern &P) { + return TLI_pred_match{[](const TargetLowering &TLI, SDValue N) { + return TLI.isTypeLegal(N.getValueType()); + }, + P}; +} + +// === Patterns combinators === +template <typename... Preds> struct And { + template <typename MatchContext> bool match(const MatchContext &, SDValue N) { + return true; + } +}; + +template <typename Pred, typename... Preds> +struct And<Pred, Preds...> : And<Preds...> { + Pred P; + And(Pred &&p, Preds &&...preds) + : And<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) { + } + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + return P.match(Ctx, N) && And<Preds...>::match(Ctx, N); + } +}; + +template <typename... Preds> struct Or { + template <typename MatchContext> bool match(const MatchContext &, SDValue N) { + return false; + } +}; + +template <typename Pred, typename... Preds> +struct Or<Pred, Preds...> : Or<Preds...> { + Pred P; + Or(Pred &&p, Preds &&...preds) + : Or<Preds...>(std::forward<Preds>(preds)...), P(std::forward<Pred>(p)) {} + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + return P.match(Ctx, N) || Or<Preds...>::match(Ctx, N); + } +}; + +template <typename... Preds> And<Preds...> m_AllOf(Preds &&...preds) { + return And<Preds...>(std::forward<Preds>(preds)...); +} + +template <typename... Preds> Or<Preds...> m_AnyOf(Preds &&...preds) { + return Or<Preds...>(std::forward<Preds>(preds)...); +} + +// === Generic node matching === +template <unsigned OpIdx, typename... OpndPreds> struct Operands_match { + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + // Returns false if there are more operands than predicates; + return N->getNumOperands() == OpIdx; + } +}; + +template <unsigned OpIdx, typename OpndPred, typename... OpndPreds> +struct Operands_match<OpIdx, OpndPred, OpndPreds...> + : Operands_match<OpIdx + 1, OpndPreds...> { + OpndPred P; + + Operands_match(OpndPred &&p, OpndPreds &&...preds) + : Operands_match<OpIdx + 1, OpndPreds...>( + std::forward<OpndPreds>(preds)...), + P(std::forward<OpndPred>(p)) {} + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + if (OpIdx < N->getNumOperands()) + return P.match(Ctx, N->getOperand(OpIdx)) && + Operands_match<OpIdx + 1, OpndPreds...>::match(Ctx, N); + + // This is the case where there are more predicates than operands. + return false; + } +}; + +template <typename... OpndPreds> +auto m_Node(unsigned Opcode, OpndPreds &&...preds) { + return m_AllOf(m_Opc(Opcode), Operands_match<0, OpndPreds...>( + std::forward<OpndPreds>(preds)...)); +} + +/// Provide number of operands that are not chain or glue, as well as the first +/// index of such operand. +template <bool ExcludeChain> struct EffectiveOperands { + unsigned Size = 0; + unsigned FirstIndex = 0; + + explicit EffectiveOperands(SDValue N) { + const unsigned TotalNumOps = N->getNumOperands(); + FirstIndex = TotalNumOps; + for (unsigned I = 0; I < TotalNumOps; ++I) { + // Count the number of non-chain and non-glue nodes (we ignore chain + // and glue by default) and retreive the operand index offset. + EVT VT = N->getOperand(I).getValueType(); + if (VT != MVT::Glue && VT != MVT::Other) { + ++Size; + if (FirstIndex == TotalNumOps) + FirstIndex = I; + } + } + } +}; + +template <> struct EffectiveOperands<false> { + unsigned Size = 0; + unsigned FirstIndex = 0; + + explicit EffectiveOperands(SDValue N) : Size(N->getNumOperands()) {} +}; + +// === Binary operations === +template <typename LHS_P, typename RHS_P, bool Commutable = false, + bool ExcludeChain = false> +struct BinaryOpc_match { + unsigned Opcode; + LHS_P LHS; + RHS_P RHS; + + BinaryOpc_match(unsigned Opc, const LHS_P &L, const RHS_P &R) + : Opcode(Opc), LHS(L), RHS(R) {} + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + if (sd_context_match(N, Ctx, m_Opc(Opcode))) { + EffectiveOperands<ExcludeChain> EO(N); + assert(EO.Size == 2); + return (LHS.match(Ctx, N->getOperand(EO.FirstIndex)) && + RHS.match(Ctx, N->getOperand(EO.FirstIndex + 1))) || + (Commutable && LHS.match(Ctx, N->getOperand(EO.FirstIndex + 1)) && + RHS.match(Ctx, N->getOperand(EO.FirstIndex))); + } + + return false; + } +}; + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_BinOp(unsigned Opc, const LHS &L, + const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(Opc, L, R); +} +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, true> m_c_BinOp(unsigned Opc, const LHS &L, + const RHS &R) { + return BinaryOpc_match<LHS, RHS, true>(Opc, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false, true> +m_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false, true>(Opc, L, R); +} +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, true, true> +m_c_ChainedBinOp(unsigned Opc, const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, true, true>(Opc, L, R); +} + +// Common binary operations +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, true> m_Add(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, true>(ISD::ADD, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_Sub(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::SUB, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, true> m_Mul(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, true>(ISD::MUL, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_UDiv(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::UDIV, L, R); +} +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_SDiv(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::SDIV, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_URem(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::UREM, L, R); +} +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_SRem(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::SREM, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_Shl(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::SHL, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_Sra(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::SRA, L, R); +} +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_Srl(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::SRL, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, true> m_FAdd(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, true>(ISD::FADD, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_FSub(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::FSUB, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, true> m_FMul(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, true>(ISD::FMUL, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_FDiv(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::FDIV, L, R); +} + +template <typename LHS, typename RHS> +inline BinaryOpc_match<LHS, RHS, false> m_FRem(const LHS &L, const RHS &R) { + return BinaryOpc_match<LHS, RHS, false>(ISD::FREM, L, R); +} + +// === Unary operations === +template <typename Opnd_P, bool ExcludeChain = false> struct UnaryOpc_match { + unsigned Opcode; + Opnd_P Opnd; + + UnaryOpc_match(unsigned Opc, const Opnd_P &Op) : Opcode(Opc), Opnd(Op) {} + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + if (sd_context_match(N, Ctx, m_Opc(Opcode))) { + EffectiveOperands<ExcludeChain> EO(N); + assert(EO.Size == 1); + return Opnd.match(Ctx, N->getOperand(EO.FirstIndex)); + } + + return false; + } +}; + +template <typename Opnd> +inline UnaryOpc_match<Opnd> m_UnaryOp(unsigned Opc, const Opnd &Op) { + return UnaryOpc_match<Opnd>(Opc, Op); +} +template <typename Opnd> +inline UnaryOpc_match<Opnd, true> m_ChainedUnaryOp(unsigned Opc, + const Opnd &Op) { + return UnaryOpc_match<Opnd, true>(Opc, Op); +} + +template <typename Opnd> inline UnaryOpc_match<Opnd> m_ZExt(const Opnd &Op) { + return UnaryOpc_match<Opnd>(ISD::ZERO_EXTEND, Op); +} + +template <typename Opnd> inline UnaryOpc_match<Opnd> m_SExt(const Opnd &Op) { + return UnaryOpc_match<Opnd>(ISD::SIGN_EXTEND, Op); +} + +template <typename Opnd> inline UnaryOpc_match<Opnd> m_AnyExt(const Opnd &Op) { + return UnaryOpc_match<Opnd>(ISD::ANY_EXTEND, Op); +} + +template <typename Opnd> inline UnaryOpc_match<Opnd> m_Trunc(const Opnd &Op) { + return UnaryOpc_match<Opnd>(ISD::TRUNCATE, Op); +} + +// === Constants === +struct ConstantInt_match { + APInt *BindVal; + + explicit ConstantInt_match(APInt *V) : BindVal(V) {} + + template <typename MatchContext> bool match(const MatchContext &, SDValue N) { + // The logics here are similar to that in + // SelectionDAG::isConstantIntBuildVectorOrConstantInt, but the latter also + // treats GlobalAddressSDNode as a constant, which is difficult to turn into + // APInt. + if (auto *C = dyn_cast_or_null<ConstantSDNode>(N.getNode())) { + if (BindVal) + *BindVal = C->getAPIntValue(); + return true; + } + + APInt Discard; + return ISD::isConstantSplatVector(N.getNode(), + BindVal ? *BindVal : Discard); + } +}; +/// Match any interger constants or splat of an integer constant. +inline ConstantInt_match m_ConstInt() { return ConstantInt_match(nullptr); } +/// Match any interger constants or splat of an integer constant; return the +/// specific constant or constant splat value. +inline ConstantInt_match m_ConstInt(APInt &V) { return ConstantInt_match(&V); } + +struct SpecificInt_match { + APInt IntVal; + + explicit SpecificInt_match(APInt APV) : IntVal(std::move(APV)) {} + + template <typename MatchContext> + bool match(const MatchContext &Ctx, SDValue N) { + APInt ConstInt; + if (sd_context_match(N, Ctx, m_ConstInt(ConstInt))) + return APInt::isSameValue(IntVal, ConstInt); + return false; + } +}; + +/// Match a specific integer constant or constant splat value. +inline SpecificInt_match m_SpecificInt(APInt V) { + return SpecificInt_match(std::move(V)); +} +inline SpecificInt_match m_SpecificInt(uint64_t V) { + return SpecificInt_match(APInt(64, V)); +} + +inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); } +inline SpecificInt_match m_AllOnes() { return m_SpecificInt(~0U); } + +/// Match true boolean value based on the information provided by +/// TargetLowering. +inline auto m_True() { + return TLI_pred_match{ + [](const TargetLowering &TLI, SDValue N) { + APInt ConstVal; + if (sd_match(N, m_ConstInt(ConstVal))) + switch (TLI.getBooleanContents(N.getValueType())) { + case TargetLowering::ZeroOrOneBooleanContent: + return ConstVal.isOne(); + case TargetLowering::ZeroOrNegativeOneBooleanContent: + return ConstVal.isAllOnes(); + case TargetLowering::UndefinedBooleanContent: + return (ConstVal & 0x01) == 1; + } + + return false; + }, + m_Value()}; +} +/// Match false boolean value based on the information provided by +/// TargetLowering. +inline auto m_False() { + return TLI_pred_match{ + [](const TargetLowering &TLI, SDValue N) { + APInt ConstVal; + if (sd_match(N, m_ConstInt(ConstVal))) + switch (TLI.getBooleanContents(N.getValueType())) { + case TargetLowering::ZeroOrOneBooleanContent: + case TargetLowering::ZeroOrNegativeOneBooleanContent: + return ConstVal.isZero(); + case TargetLowering::UndefinedBooleanContent: + return (ConstVal & 0x01) == 0; + } + + return false; + }, + m_Value()}; +} +} // namespace SDPatternMatch +} // namespace llvm +#endif diff --git a/llvm/unittests/CodeGen/CMakeLists.txt b/llvm/unittests/CodeGen/CMakeLists.txt index 6140e0d6fb37..dbbacdd95ec9 100644 --- a/llvm/unittests/CodeGen/CMakeLists.txt +++ b/llvm/unittests/CodeGen/CMakeLists.txt @@ -40,6 +40,7 @@ add_llvm_unittest(CodeGenTests ScalableVectorMVTsTest.cpp SchedBoundary.cpp SelectionDAGAddressAnalysisTest.cpp + SelectionDAGPatternMatchTest.cpp TypeTraitsTest.cpp TargetOptionsTest.cpp TestAsmPrinter.cpp diff --git a/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp new file mode 100644 index 000000000000..17fc3ce8af26 --- /dev/null +++ b/llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp @@ -0,0 +1,292 @@ +//===---- llvm/unittest/CodeGen/SelectionDAGPatternMatchTest.cpp ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/CodeGen/MachineModuleInfo.h" +#include "llvm/CodeGen/SDPatternMatch.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" +#include "gtest/gtest.h" + +using namespace llvm; + +class SelectionDAGPatternMatchTest : public testing::Test { +protected: + static void SetUpTestCase() { + InitializeAllTargets(); + InitializeAllTargetMCs(); + } + + void SetUp() override { + StringRef Assembly = "@g = global i32 0\n" + "@g_alias = alias i32, i32* @g\n" + "define i32 @f() {\n" + " %1 = load i32, i32* @g\n" + " ret i32 %1\n" + "}"; + + Triple TargetTriple("riscv64--"); + std::string Error; + const Target *T = TargetRegistry::lookupTarget("", TargetTriple, Error); + // FIXME: These tests do not depend on RISCV specifically, but we have to + // initialize a target. A skeleton Target for unittests would allow us to + // always run these tests. + if (!T) + GTEST_SKIP(); + + TargetOptions Options; + TM = std::unique_ptr<LLVMTargetMachine>(static_cast<LLVMTargetMachine *>( + T->createTargetMachine("riscv64", "", "+m,+f,+d,+v", Options, + std::nullopt, std::nullopt, + CodeGenOptLevel::Aggressive))); + if (!TM) + GTEST_SKIP(); + + SMDiagnostic SMError; + M = parseAssemblyString(Assembly, SMError, Context); + if (!M) + report_fatal_error(SMError.getMessage()); + M->setDataLayout(TM->createDataLayout()); + + F = M->getFunction("f"); + if (!F) + report_fatal_error("F?"); + G = M->getGlobalVariable("g"); + if (!G) + report_fatal_error("G?"); + AliasedG = M->getNamedAlias("g_alias"); + if (!AliasedG) + report_fatal_error("AliasedG?"); + + MachineModuleInfo MMI(TM.get()); + + MF = std::make_unique<MachineFunction>(*F, *TM, *TM->getSubtargetImpl(*F), + 0, MMI); + + DAG = std::make_unique<SelectionDAG>(*TM, CodeGenOptLevel::None); + if (!DAG) + report_fatal_error("DAG?"); + OptimizationRemarkEmitter ORE(F); + DAG->init(*MF, ORE, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); + } + + TargetLoweringBase::LegalizeTypeAction getTypeAction(EVT VT) { + return DAG->getTargetLoweringInfo().getTypeAction(Context, VT); + } + + EVT getTypeToTransformTo(EVT VT) { + return DAG->getTargetLoweringInfo().getTypeToTransformTo(Context, VT); + } + + LLVMContext Context; + std::unique_ptr<LLVMTargetMachine> TM; + std::unique_ptr<Module> M; + Function *F; + GlobalVariable *G; + GlobalAlias *AliasedG; + std::unique_ptr<MachineFunction> MF; + std::unique_ptr<SelectionDAG> DAG; +}; + +TEST_F(SelectionDAGPatternMatchTest, matchValueType) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + auto Float32VT = EVT::getFloatingPointVT(32); + auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Float32VT); + SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match(Op0, m_SpecificVT(Int32VT))); + EVT BindVT; + EXPECT_TRUE(sd_match(Op1, m_VT(BindVT))); + EXPECT_EQ(BindVT, Float32VT); + EXPECT_TRUE(sd_match(Op0, m_IntegerVT())); + EXPECT_TRUE(sd_match(Op1, m_FloatingPointVT())); + EXPECT_TRUE(sd_match(Op2, m_VectorVT())); + EXPECT_FALSE(sd_match(Op2, m_ScalableVectorVT())); +} + +TEST_F(SelectionDAGPatternMatchTest, matchBinaryOp) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + auto Float32VT = EVT::getFloatingPointVT(32); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); + SDValue Op2 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, Float32VT); + + SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1); + SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0); + SDValue Mul = DAG->getNode(ISD::MUL, DL, Int32VT, Add, Sub); + + SDValue SFAdd = DAG->getNode(ISD::STRICT_FADD, DL, {Float32VT, MVT::Other}, + {DAG->getEntryNode(), Op2, Op2}); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match(Sub, m_BinOp(ISD::SUB, m_Value(), m_Value()))); + EXPECT_TRUE(sd_match(Sub, m_Sub(m_Value(), m_Value()))); + EXPECT_TRUE(sd_match(Add, m_c_BinOp(ISD::ADD, m_Value(), m_Value()))); + EXPECT_TRUE(sd_match(Add, m_Add(m_Value(), m_Value()))); + EXPECT_TRUE(sd_match( + Mul, m_Mul(m_OneUse(m_Opc(ISD::SUB)), m_NUses<2>(m_Specific(Add))))); + EXPECT_TRUE( + sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_SpecificVT(Float32VT), + m_SpecificVT(Float32VT)))); + SDValue BindVal; + EXPECT_TRUE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_Value(BindVal), + m_Deferred(BindVal)))); + EXPECT_FALSE(sd_match(SFAdd, m_ChainedBinOp(ISD::STRICT_FADD, m_OtherVT(), + m_SpecificVT(Float32VT)))); +} + +TEST_F(SelectionDAGPatternMatchTest, matchUnaryOp) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + auto Int64VT = EVT::getIntegerVT(Context, 64); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT); + + SDValue ZExt = DAG->getNode(ISD::ZERO_EXTEND, DL, Int64VT, Op0); + SDValue SExt = DAG->getNode(ISD::SIGN_EXTEND, DL, Int64VT, Op0); + SDValue Trunc = DAG->getNode(ISD::TRUNCATE, DL, Int32VT, Op1); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match(ZExt, m_UnaryOp(ISD::ZERO_EXTEND, m_Value()))); + EXPECT_TRUE(sd_match(SExt, m_SExt(m_Value()))); + EXPECT_TRUE(sd_match(Trunc, m_Trunc(m_Specific(Op1)))); +} + +TEST_F(SelectionDAGPatternMatchTest, matchConstants) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); + + SDValue Arg0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); + + SDValue Const3 = DAG->getConstant(3, DL, Int32VT); + SDValue Const87 = DAG->getConstant(87, DL, Int32VT); + SDValue Splat = DAG->getSplat(VInt32VT, DL, Arg0); + SDValue ConstSplat = DAG->getSplat(VInt32VT, DL, Const3); + SDValue Zero = DAG->getConstant(0, DL, Int32VT); + SDValue One = DAG->getConstant(1, DL, Int32VT); + SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, Int32VT); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match(Const87, m_ConstInt())); + EXPECT_FALSE(sd_match(Arg0, m_ConstInt())); + APInt ConstVal; + EXPECT_TRUE(sd_match(ConstSplat, m_ConstInt(ConstVal))); + EXPECT_EQ(ConstVal, 3); + EXPECT_FALSE(sd_match(Splat, m_ConstInt())); + + EXPECT_TRUE(sd_match(Const87, m_SpecificInt(87))); + EXPECT_TRUE(sd_match(Const3, m_SpecificInt(ConstVal))); + EXPECT_TRUE(sd_match(AllOnes, m_AllOnes())); + + EXPECT_TRUE(sd_match(Zero, DAG.get(), m_False())); + EXPECT_TRUE(sd_match(One, DAG.get(), m_True())); + EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_True())); +} + +TEST_F(SelectionDAGPatternMatchTest, patternCombinators) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); + + SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1); + SDValue Sub = DAG->getNode(ISD::SUB, DL, Int32VT, Add, Op0); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match( + Sub, m_AnyOf(m_Opc(ISD::ADD), m_Opc(ISD::SUB), m_Opc(ISD::MUL)))); + EXPECT_TRUE(sd_match(Add, m_AllOf(m_Opc(ISD::ADD), m_OneUse()))); +} + +TEST_F(SelectionDAGPatternMatchTest, matchNode) { + SDLoc DL; + auto Int32VT = EVT::getIntegerVT(Context, 32); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int32VT); + + SDValue Add = DAG->getNode(ISD::ADD, DL, Int32VT, Op0, Op1); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(Add, m_Node(ISD::SUB, m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_Value()))); + EXPECT_FALSE( + sd_match(Add, m_Node(ISD::ADD, m_Value(), m_Value(), m_Value()))); + EXPECT_FALSE(sd_match(Add, m_Node(ISD::ADD, m_ConstInt(), m_Value()))); +} + +namespace { +struct VPMatchContext : public SDPatternMatch::BasicMatchContext { + using SDPatternMatch::BasicMatchContext::BasicMatchContext; + + bool match(SDValue OpVal, unsigned Opc) const { + if (!OpVal->isVPOpcode()) + return OpVal->getOpcode() == Opc; + + auto BaseOpc = ISD::getBaseOpcodeForVP(OpVal->getOpcode(), false); + return BaseOpc.has_value() && *BaseOpc == Opc; + } +}; +} // anonymous namespace +TEST_F(SelectionDAGPatternMatchTest, matchContext) { + SDLoc DL; + auto BoolVT = EVT::getIntegerVT(Context, 1); + auto Int32VT = EVT::getIntegerVT(Context, 32); + auto VInt32VT = EVT::getVectorVT(Context, Int32VT, 4); + auto MaskVT = EVT::getVectorVT(Context, BoolVT, 4); + + SDValue Scalar0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int32VT); + SDValue Vector0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, VInt32VT); + SDValue Mask0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 3, MaskVT); + + SDValue VPAdd = DAG->getNode(ISD::VP_ADD, DL, VInt32VT, + {Vector0, Vector0, Mask0, Scalar0}); + SDValue VPReduceAdd = DAG->getNode(ISD::VP_REDUCE_ADD, DL, Int32VT, + {Scalar0, VPAdd, Mask0, Scalar0}); + + using namespace SDPatternMatch; + VPMatchContext VPCtx(DAG.get()); + EXPECT_TRUE(sd_context_match(VPAdd, VPCtx, m_Opc(ISD::ADD))); + // VP_REDUCE_ADD doesn't have a based opcode, so we use a normal + // sd_match before switching to VPMatchContext when checking VPAdd. + EXPECT_TRUE(sd_match(VPReduceAdd, m_Node(ISD::VP_REDUCE_ADD, m_Value(), + m_Context(VPCtx, m_Opc(ISD::ADD)), + m_Value(), m_Value()))); +} + +TEST_F(SelectionDAGPatternMatchTest, matchAdvancedProperties) { + SDLoc DL; + auto Int16VT = EVT::getIntegerVT(Context, 16); + auto Int64VT = EVT::getIntegerVT(Context, 64); + + SDValue Op0 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 1, Int64VT); + SDValue Op1 = DAG->getCopyFromReg(DAG->getEntryNode(), DL, 2, Int16VT); + + SDValue Add = DAG->getNode(ISD::ADD, DL, Int64VT, Op0, Op0); + + using namespace SDPatternMatch; + EXPECT_TRUE(sd_match(Op0, DAG.get(), m_LegalType(m_Value()))); + EXPECT_FALSE(sd_match(Op1, DAG.get(), m_LegalType(m_Value()))); + EXPECT_TRUE(sd_match(Add, DAG.get(), + m_LegalOp(m_IntegerVT(m_Add(m_Value(), m_Value()))))); +} |