//===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file provides a simple and efficient mechanism for performing general // tree-based pattern matches on the VPlan values and recipes, based on // LLVM's IR pattern matchers. // //===----------------------------------------------------------------------===// #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H #include "VPlan.h" namespace llvm::VPlanPatternMatch { template bool match(Val *V, const Pattern &P) { return P.match(V); } template bool match(VPUser *U, const Pattern &P) { auto *R = dyn_cast(U); return R && match(R, P); } template bool match(VPSingleDefRecipe *R, const Pattern &P) { return P.match(static_cast(R)); } template struct VPMatchFunctor { const Pattern &P; VPMatchFunctor(const Pattern &P) : P(P) {} bool operator()(Val *V) const { return match(V, P); } }; /// A match functor that can be used as a UnaryPredicate in functional /// algorithms like all_of. template VPMatchFunctor match_fn(const Pattern &P) { return P; } template struct class_match { template bool match(ITy *V) const { return isa(V); } }; /// Match an arbitrary VPValue and ignore it. inline class_match m_VPValue() { return class_match(); } template struct bind_ty { Class *&VR; bind_ty(Class *&V) : VR(V) {} template bool match(ITy *V) const { if (auto *CV = dyn_cast(V)) { VR = CV; return true; } return false; } }; /// Match a specified VPValue. struct specificval_ty { const VPValue *Val; specificval_ty(const VPValue *V) : Val(V) {} bool match(VPValue *VPV) const { return VPV == Val; } }; inline specificval_ty m_Specific(const VPValue *VPV) { return VPV; } /// Stores a reference to the VPValue *, not the VPValue * itself, /// thus can be used in commutative matchers. struct deferredval_ty { VPValue *const &Val; deferredval_ty(VPValue *const &V) : Val(V) {} bool match(VPValue *const V) const { return V == Val; } }; /// Like m_Specific(), but works if the specific value to match is determined /// as part of the same match() expression. For example: /// m_Mul(m_VPValue(X), m_Specific(X)) is incorrect, because m_Specific() will /// bind X before the pattern match starts. /// m_Mul(m_VPValue(X), m_Deferred(X)) is correct, and will check against /// whichever value m_VPValue(X) populated. inline deferredval_ty m_Deferred(VPValue *const &V) { return V; } /// Match an integer constant or vector of constants if Pred::isValue returns /// true for the APInt. \p BitWidth optionally specifies the bitwidth the /// matched constant must have. If it is 0, the matched constant can have any /// bitwidth. template struct int_pred_ty { Pred P; int_pred_ty(Pred P) : P(std::move(P)) {} int_pred_ty() : P() {} bool match(VPValue *VPV) const { if (!VPV->isLiveIn()) return false; Value *V = VPV->getLiveInIRValue(); if (!V) return false; assert(!V->getType()->isVectorTy() && "Unexpected vector live-in"); const auto *CI = dyn_cast(V); if (!CI) return false; if (BitWidth != 0 && CI->getBitWidth() != BitWidth) return false; return P.isValue(CI->getValue()); } }; /// Match a specified integer value or vector of all elements of that /// value. \p BitWidth optionally specifies the bitwidth the matched constant /// must have. If it is 0, the matched constant can have any bitwidth. struct is_specific_int { APInt Val; is_specific_int(APInt Val) : Val(std::move(Val)) {} bool isValue(const APInt &C) const { return APInt::isSameValue(Val, C); } }; template using specific_intval = int_pred_ty; inline specific_intval<0> m_SpecificInt(uint64_t V) { return specific_intval<0>(is_specific_int(APInt(64, V))); } inline specific_intval<1> m_False() { return specific_intval<1>(is_specific_int(APInt(64, 0))); } inline specific_intval<1> m_True() { return specific_intval<1>(is_specific_int(APInt(64, 1))); } struct is_all_ones { bool isValue(const APInt &C) const { return C.isAllOnes(); } }; /// Match an integer or vector with all bits set. /// For vectors, this includes constants with undefined elements. inline int_pred_ty m_AllOnes() { return int_pred_ty(); } struct is_zero_int { bool isValue(const APInt &C) const { return C.isZero(); } }; struct is_one { bool isValue(const APInt &C) const { return C.isOne(); } }; /// Match an integer 0 or a vector with all elements equal to 0. /// For vectors, this includes constants with undefined elements. inline int_pred_ty m_ZeroInt() { return int_pred_ty(); } /// Match an integer 1 or a vector with all elements equal to 1. /// For vectors, this includes constants with undefined elements. inline int_pred_ty m_One() { return int_pred_ty(); } struct bind_apint { const APInt *&Res; bind_apint(const APInt *&Res) : Res(Res) {} bool match(VPValue *VPV) const { if (!VPV->isLiveIn()) return false; Value *V = VPV->getLiveInIRValue(); if (!V) return false; assert(!V->getType()->isVectorTy() && "Unexpected vector live-in"); const auto *CI = dyn_cast(V); if (!CI) return false; Res = &CI->getValue(); return true; } }; inline bind_apint m_APInt(const APInt *&C) { return C; } struct bind_const_int { uint64_t &Res; bind_const_int(uint64_t &Res) : Res(Res) {} bool match(VPValue *VPV) const { const APInt *APConst; if (!bind_apint(APConst).match(VPV)) return false; if (auto C = APConst->tryZExtValue()) { Res = *C; return true; } return false; } }; /// Match a plain integer constant no wider than 64-bits, capturing it if we /// match. inline bind_const_int m_ConstantInt(uint64_t &C) { return C; } /// Matching combinators template struct match_combine_or { LTy L; RTy R; match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} template bool match(ITy *V) const { return L.match(V) || R.match(V); } }; template struct match_combine_and { LTy L; RTy R; match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} template bool match(ITy *V) const { return L.match(V) && R.match(V); } }; /// Combine two pattern matchers matching L || R template inline match_combine_or m_CombineOr(const LTy &L, const RTy &R) { return match_combine_or(L, R); } /// Combine two pattern matchers matching L && R template inline match_combine_and m_CombineAnd(const LTy &L, const RTy &R) { return match_combine_and(L, R); } /// Match a VPValue, capturing it if we match. inline bind_ty m_VPValue(VPValue *&V) { return V; } /// Match a VPInstruction, capturing if we match. inline bind_ty m_VPInstruction(VPInstruction *&V) { return V; } template struct Recipe_match { Ops_t Ops; template Recipe_match(OpTy... Ops) : Ops(Ops...) { static_assert(std::tuple_size::value == sizeof...(Ops) && "number of operands in constructor doesn't match Ops_t"); static_assert((!Commutative || std::tuple_size::value == 2) && "only binary ops can be commutative"); } bool match(const VPValue *V) const { auto *DefR = V->getDefiningRecipe(); return DefR && match(DefR); } bool match(const VPSingleDefRecipe *R) const { return match(static_cast(R)); } bool match(const VPRecipeBase *R) const { if (std::tuple_size_v == 0) { auto *VPI = dyn_cast(R); return VPI && VPI->getOpcode() == Opcode; } if ((!matchRecipeAndOpcode(R) && ...)) return false; if (R->getNumOperands() != std::tuple_size::value) { assert(Opcode == Instruction::PHI && "non-variadic recipe with matched opcode does not have the " "expected number of operands"); return false; } auto IdxSeq = std::make_index_sequence::value>(); if (all_of_tuple_elements(IdxSeq, [R](auto Op, unsigned Idx) { return Op.match(R->getOperand(Idx)); })) return true; return Commutative && all_of_tuple_elements(IdxSeq, [R](auto Op, unsigned Idx) { return Op.match(R->getOperand(R->getNumOperands() - Idx - 1)); }); } private: template static bool matchRecipeAndOpcode(const VPRecipeBase *R) { auto *DefR = dyn_cast(R); // Check for recipes that do not have opcodes. if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) return DefR; else return DefR && DefR->getOpcode() == Opcode; } /// Helper to check if predicate \p P holds on all tuple elements in Ops using /// the provided index sequence. template bool all_of_tuple_elements(std::index_sequence, Fn P) const { return (P(std::get(Ops), Is) && ...); } }; template using AllRecipe_match = Recipe_match, Opcode, /*Commutative*/ false, VPWidenRecipe, VPReplicateRecipe, VPWidenCastRecipe, VPInstruction, VPWidenSelectRecipe>; template using AllRecipe_commutative_match = Recipe_match, Opcode, /*Commutative*/ true, VPWidenRecipe, VPReplicateRecipe, VPInstruction>; template using VPInstruction_match = Recipe_match, Opcode, /*Commutative*/ false, VPInstruction>; template inline VPInstruction_match m_VPInstruction(const OpTys &...Ops) { return VPInstruction_match(Ops...); } /// BuildVector is matches only its opcode, w/o matching its operands as the /// number of operands is not fixed. inline VPInstruction_match m_BuildVector() { return m_VPInstruction(); } template inline VPInstruction_match m_Freeze(const Op0_t &Op0) { return m_VPInstruction(Op0); } inline VPInstruction_match m_BranchOnCond() { return m_VPInstruction(); } template inline VPInstruction_match m_BranchOnCond(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline VPInstruction_match m_Broadcast(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline VPInstruction_match m_EVL(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline VPInstruction_match m_ExtractLastLane(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline VPInstruction_match m_ExtractElement(const Op0_t &Op0, const Op1_t &Op1) { return m_VPInstruction(Op0, Op1); } template inline VPInstruction_match m_ExtractLane(const Op0_t &Op0, const Op1_t &Op1) { return m_VPInstruction(Op0, Op1); } template inline VPInstruction_match m_ExtractLastPart(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline VPInstruction_match< VPInstruction::ExtractLastLane, VPInstruction_match> m_ExtractLastLaneOfLastPart(const Op0_t &Op0) { return m_ExtractLastLane(m_ExtractLastPart(Op0)); } template inline VPInstruction_match m_ExtractPenultimateElement(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline VPInstruction_match m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { return m_VPInstruction(Op0, Op1, Op2); } inline VPInstruction_match m_BranchOnCount() { return m_VPInstruction(); } template inline VPInstruction_match m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) { return m_VPInstruction(Op0, Op1); } inline VPInstruction_match m_AnyOf() { return m_VPInstruction(); } template inline VPInstruction_match m_AnyOf(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline VPInstruction_match m_FirstActiveLane(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline VPInstruction_match m_LastActiveLane(const Op0_t &Op0) { return m_VPInstruction(Op0); } inline VPInstruction_match m_StepVector() { return m_VPInstruction(); } template inline AllRecipe_match m_Unary(const Op0_t &Op0) { return AllRecipe_match(Op0); } template inline AllRecipe_match m_Trunc(const Op0_t &Op0) { return m_Unary(Op0); } template inline match_combine_or, Op0_t> m_TruncOrSelf(const Op0_t &Op0) { return m_CombineOr(m_Trunc(Op0), Op0); } template inline AllRecipe_match m_ZExt(const Op0_t &Op0) { return m_Unary(Op0); } template inline AllRecipe_match m_SExt(const Op0_t &Op0) { return m_Unary(Op0); } template inline match_combine_or, AllRecipe_match> m_ZExtOrSExt(const Op0_t &Op0) { return m_CombineOr(m_ZExt(Op0), m_SExt(Op0)); } template inline match_combine_or, Op0_t> m_ZExtOrSelf(const Op0_t &Op0) { return m_CombineOr(m_ZExt(Op0), Op0); } template inline AllRecipe_match m_Binary(const Op0_t &Op0, const Op1_t &Op1) { return AllRecipe_match(Op0, Op1); } template inline AllRecipe_commutative_match m_c_Binary(const Op0_t &Op0, const Op1_t &Op1) { return AllRecipe_commutative_match(Op0, Op1); } template inline AllRecipe_match m_Add(const Op0_t &Op0, const Op1_t &Op1) { return m_Binary(Op0, Op1); } template inline AllRecipe_commutative_match m_c_Add(const Op0_t &Op0, const Op1_t &Op1) { return m_c_Binary(Op0, Op1); } template inline AllRecipe_match m_Sub(const Op0_t &Op0, const Op1_t &Op1) { return m_Binary(Op0, Op1); } template inline AllRecipe_match m_Mul(const Op0_t &Op0, const Op1_t &Op1) { return m_Binary(Op0, Op1); } template inline AllRecipe_commutative_match m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) { return m_c_Binary(Op0, Op1); } /// Match a binary AND operation. template inline AllRecipe_commutative_match m_c_BinaryAnd(const Op0_t &Op0, const Op1_t &Op1) { return m_c_Binary(Op0, Op1); } /// Match a binary OR operation. Note that while conceptually the operands can /// be matched commutatively, \p Commutative defaults to false in line with the /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative /// version of the matcher. template inline AllRecipe_match m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { return m_Binary(Op0, Op1); } template inline AllRecipe_commutative_match m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { return m_c_Binary(Op0, Op1); } /// Cmp_match is a variant of BinaryRecipe_match that also binds the comparison /// predicate. Opcodes must either be Instruction::ICmp or Instruction::FCmp, or /// both. template struct Cmp_match { static_assert((sizeof...(Opcodes) == 1 || sizeof...(Opcodes) == 2) && "Expected one or two opcodes"); static_assert( ((Opcodes == Instruction::ICmp || Opcodes == Instruction::FCmp) && ...) && "Expected a compare instruction opcode"); CmpPredicate *Predicate = nullptr; Op0_t Op0; Op1_t Op1; Cmp_match(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) : Predicate(&Pred), Op0(Op0), Op1(Op1) {} Cmp_match(const Op0_t &Op0, const Op1_t &Op1) : Op0(Op0), Op1(Op1) {} bool match(const VPValue *V) const { auto *DefR = V->getDefiningRecipe(); return DefR && match(DefR); } bool match(const VPRecipeBase *V) const { if ((m_Binary(Op0, Op1).match(V) || ...)) { if (Predicate) *Predicate = cast(V)->getPredicate(); return true; } return false; } }; /// SpecificCmp_match is a variant of Cmp_match that matches the comparison /// predicate, instead of binding it. template struct SpecificCmp_match { const CmpPredicate Predicate; Op0_t Op0; Op1_t Op1; SpecificCmp_match(CmpPredicate Pred, const Op0_t &LHS, const Op1_t &RHS) : Predicate(Pred), Op0(LHS), Op1(RHS) {} bool match(const VPValue *V) const { CmpPredicate CurrentPred; return Cmp_match(CurrentPred, Op0, Op1) .match(V) && CmpPredicate::getMatching(CurrentPred, Predicate); } }; template inline Cmp_match m_ICmp(const Op0_t &Op0, const Op1_t &Op1) { return Cmp_match(Op0, Op1); } template inline Cmp_match m_ICmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) { return Cmp_match(Pred, Op0, Op1); } template inline SpecificCmp_match m_SpecificICmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) { return SpecificCmp_match(MatchPred, Op0, Op1); } template inline Cmp_match m_Cmp(const Op0_t &Op0, const Op1_t &Op1) { return Cmp_match(Op0, Op1); } template inline Cmp_match m_Cmp(CmpPredicate &Pred, const Op0_t &Op0, const Op1_t &Op1) { return Cmp_match( Pred, Op0, Op1); } template inline SpecificCmp_match m_SpecificCmp(CmpPredicate MatchPred, const Op0_t &Op0, const Op1_t &Op1) { return SpecificCmp_match( MatchPred, Op0, Op1); } template using GEPLikeRecipe_match = match_combine_or< Recipe_match, Instruction::GetElementPtr, /*Commutative*/ false, VPReplicateRecipe, VPWidenGEPRecipe>, match_combine_or< VPInstruction_match, VPInstruction_match>>; template inline GEPLikeRecipe_match m_GetElementPtr(const Op0_t &Op0, const Op1_t &Op1) { return m_CombineOr( Recipe_match, Instruction::GetElementPtr, /*Commutative*/ false, VPReplicateRecipe, VPWidenGEPRecipe>( Op0, Op1), m_CombineOr( VPInstruction_match(Op0, Op1), VPInstruction_match(Op0, Op1))); } template inline AllRecipe_match m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { return AllRecipe_match( {Op0, Op1, Op2}); } template inline match_combine_or, AllRecipe_commutative_match< Instruction::Xor, int_pred_ty, Op0_t>> m_Not(const Op0_t &Op0) { return m_CombineOr(m_VPInstruction(Op0), m_c_Binary(m_AllOnes(), Op0)); } template inline match_combine_or< VPInstruction_match, AllRecipe_match>> m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) { return m_CombineOr( m_VPInstruction(Op0, Op1), m_Select(Op0, Op1, m_False())); } template inline AllRecipe_match, Op1_t> m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) { return m_Select(Op0, m_True(), Op1); } template using VPScalarIVSteps_match = Recipe_match, 0, false, VPScalarIVStepsRecipe>; template inline VPScalarIVSteps_match m_ScalarIVSteps(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { return VPScalarIVSteps_match({Op0, Op1, Op2}); } template using VPDerivedIV_match = Recipe_match, 0, false, VPDerivedIVRecipe>; template inline VPDerivedIV_match m_DerivedIV(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { return VPDerivedIV_match({Op0, Op1, Op2}); } template struct Load_match { Addr_t Addr; Mask_t Mask; Load_match(Addr_t Addr, Mask_t Mask) : Addr(Addr), Mask(Mask) {} template bool match(const OpTy *V) const { auto *Load = dyn_cast(V); if (!Load || !Addr.match(Load->getAddr()) || !Load->isMasked() || !Mask.match(Load->getMask())) return false; return true; } }; /// Match a (possibly reversed) masked load. template inline Load_match m_MaskedLoad(const Addr_t &Addr, const Mask_t &Mask) { return Load_match(Addr, Mask); } template struct Store_match { Addr_t Addr; Val_t Val; Mask_t Mask; Store_match(Addr_t Addr, Val_t Val, Mask_t Mask) : Addr(Addr), Val(Val), Mask(Mask) {} template bool match(const OpTy *V) const { auto *Store = dyn_cast(V); if (!Store || !Addr.match(Store->getAddr()) || !Val.match(Store->getStoredValue()) || !Store->isMasked() || !Mask.match(Store->getMask())) return false; return true; } }; /// Match a (possibly reversed) masked store. template inline Store_match m_MaskedStore(const Addr_t &Addr, const Val_t &Val, const Mask_t &Mask) { return Store_match(Addr, Val, Mask); } template using VectorEndPointerRecipe_match = Recipe_match, 0, /*Commutative*/ false, VPVectorEndPointerRecipe>; template VectorEndPointerRecipe_match m_VecEndPtr(const Op0_t &Op0, const Op1_t &Op1) { return VectorEndPointerRecipe_match(Op0, Op1); } /// Match a call argument at a given argument index. template struct Argument_match { /// Call argument index to match. unsigned OpI; Opnd_t Val; Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} template bool match(OpTy *V) const { if (const auto *R = dyn_cast(V)) return Val.match(R->getOperand(OpI)); if (const auto *R = dyn_cast(V)) return Val.match(R->getOperand(OpI)); if (const auto *R = dyn_cast(V)) if (isa(R->getUnderlyingInstr())) return Val.match(R->getOperand(OpI + 1)); return false; } }; /// Match a call argument. template inline Argument_match m_Argument(const Opnd_t &Op) { return Argument_match(OpI, Op); } /// Intrinsic matchers. struct IntrinsicID_match { unsigned ID; IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} template bool match(OpTy *V) const { if (const auto *R = dyn_cast(V)) return R->getVectorIntrinsicID() == ID; if (const auto *R = dyn_cast(V)) return R->getCalledScalarFunction()->getIntrinsicID() == ID; if (const auto *R = dyn_cast(V)) if (const auto *CI = dyn_cast(R->getUnderlyingInstr())) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; } }; /// Intrinsic matches are combinations of ID matchers, and argument /// matchers. Higher arity matcher are defined recursively in terms of and-ing /// them with lower arity matchers. Here's some convenient typedefs for up to /// several arguments, and more can be added as needed template struct m_Intrinsic_Ty; template struct m_Intrinsic_Ty { using Ty = match_combine_and>; }; template struct m_Intrinsic_Ty { using Ty = match_combine_and::Ty, Argument_match>; }; template struct m_Intrinsic_Ty { using Ty = match_combine_and::Ty, Argument_match>; }; template struct m_Intrinsic_Ty { using Ty = match_combine_and::Ty, Argument_match>; }; /// Match intrinsic calls like this: /// m_Intrinsic(m_VPValue(X), ...) template inline IntrinsicID_match m_Intrinsic() { return IntrinsicID_match(IntrID); } /// Match intrinsic calls with a runtime intrinsic ID. inline IntrinsicID_match m_Intrinsic(Intrinsic::ID IntrID) { return IntrinsicID_match(IntrID); } template inline typename m_Intrinsic_Ty::Ty m_Intrinsic(const T0 &Op0) { return m_CombineAnd(m_Intrinsic(), m_Argument<0>(Op0)); } template inline typename m_Intrinsic_Ty::Ty m_Intrinsic(const T0 &Op0, const T1 &Op1) { return m_CombineAnd(m_Intrinsic(Op0), m_Argument<1>(Op1)); } template inline typename m_Intrinsic_Ty::Ty m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2) { return m_CombineAnd(m_Intrinsic(Op0, Op1), m_Argument<2>(Op2)); } template inline typename m_Intrinsic_Ty::Ty m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) { return m_CombineAnd(m_Intrinsic(Op0, Op1, Op2), m_Argument<3>(Op3)); } struct live_in_vpvalue { template bool match(ITy *V) const { VPValue *Val = dyn_cast(V); return Val && Val->isLiveIn(); } }; inline live_in_vpvalue m_LiveIn() { return live_in_vpvalue(); } template struct OneUse_match { SubPattern_t SubPattern; OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} template bool match(OpTy *V) { return V->hasOneUse() && SubPattern.match(V); } }; template inline OneUse_match m_OneUse(const T &SubPattern) { return SubPattern; } } // namespace llvm::VPlanPatternMatch #endif