//===- VPlanPatternMatch.h - Match on VPValues and recipes ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file provides a simple and efficient mechanism for performing general // tree-based pattern matches on the VPlan values and recipes, based on // LLVM's IR pattern matchers. // // Currently it provides generic matchers for unary and binary VPInstructions, // and specialized matchers like m_Not, m_ActiveLaneMask, m_BranchOnCond, // m_BranchOnCount to match specific VPInstructions. // TODO: Add missing matchers for additional opcodes and recipes as needed. // //===----------------------------------------------------------------------===// #ifndef LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H #define LLVM_TRANSFORM_VECTORIZE_VPLANPATTERNMATCH_H #include "VPlan.h" namespace llvm { namespace VPlanPatternMatch { template bool match(Val *V, const Pattern &P) { return P.match(V); } template bool match(VPUser *U, const Pattern &P) { auto *R = dyn_cast(U); return R && match(R, P); } template struct class_match { template bool match(ITy *V) const { return isa(V); } }; /// Match an arbitrary VPValue and ignore it. inline class_match m_VPValue() { return class_match(); } template struct bind_ty { Class *&VR; bind_ty(Class *&V) : VR(V) {} template bool match(ITy *V) const { if (auto *CV = dyn_cast(V)) { VR = CV; return true; } return false; } }; /// Match a specified VPValue. struct specificval_ty { const VPValue *Val; specificval_ty(const VPValue *V) : Val(V) {} bool match(VPValue *VPV) const { return VPV == Val; } }; inline specificval_ty m_Specific(const VPValue *VPV) { return VPV; } /// Stores a reference to the VPValue *, not the VPValue * itself, /// thus can be used in commutative matchers. struct deferredval_ty { VPValue *const &Val; deferredval_ty(VPValue *const &V) : Val(V) {} bool match(VPValue *const V) const { return V == Val; } }; /// Like m_Specific(), but works if the specific value to match is determined /// as part of the same match() expression. For example: /// m_Mul(m_VPValue(X), m_Specific(X)) is incorrect, because m_Specific() will /// bind X before the pattern match starts. /// m_Mul(m_VPValue(X), m_Deferred(X)) is correct, and will check against /// whichever value m_VPValue(X) populated. inline deferredval_ty m_Deferred(VPValue *const &V) { return V; } /// Match an integer constant or vector of constants if Pred::isValue returns /// true for the APInt. \p BitWidth optionally specifies the bitwidth the /// matched constant must have. If it is 0, the matched constant can have any /// bitwidth. template struct int_pred_ty { Pred P; int_pred_ty(Pred P) : P(std::move(P)) {} int_pred_ty() : P() {} bool match(VPValue *VPV) const { if (!VPV->isLiveIn()) return false; Value *V = VPV->getLiveInIRValue(); if (!V) return false; const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) CI = dyn_cast_or_null( C->getSplatValue(/*AllowPoison=*/false)); if (!CI) return false; if (BitWidth != 0 && CI->getBitWidth() != BitWidth) return false; return P.isValue(CI->getValue()); } }; /// Match a specified integer value or vector of all elements of that /// value. \p BitWidth optionally specifies the bitwidth the matched constant /// must have. If it is 0, the matched constant can have any bitwidth. struct is_specific_int { APInt Val; is_specific_int(APInt Val) : Val(std::move(Val)) {} bool isValue(const APInt &C) const { return APInt::isSameValue(Val, C); } }; template using specific_intval = int_pred_ty; inline specific_intval<0> m_SpecificInt(uint64_t V) { return specific_intval<0>(is_specific_int(APInt(64, V))); } inline specific_intval<1> m_False() { return specific_intval<1>(is_specific_int(APInt(64, 0))); } inline specific_intval<1> m_True() { return specific_intval<1>(is_specific_int(APInt(64, 1))); } struct is_all_ones { bool isValue(const APInt &C) const { return C.isAllOnes(); } }; /// Match an integer or vector with all bits set. /// For vectors, this includes constants with undefined elements. inline int_pred_ty m_AllOnes() { return int_pred_ty(); } /// Matching combinators template struct match_combine_or { LTy L; RTy R; match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} template bool match(ITy *V) const { if (L.match(V)) return true; if (R.match(V)) return true; return false; } }; template struct match_combine_and { LTy L; RTy R; match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} template bool match(ITy *V) const { return L.match(V) && R.match(V); } }; /// Combine two pattern matchers matching L || R template inline match_combine_or m_CombineOr(const LTy &L, const RTy &R) { return match_combine_or(L, R); } /// Combine two pattern matchers matching L && R template inline match_combine_and m_CombineAnd(const LTy &L, const RTy &R) { return match_combine_and(L, R); } /// Match a VPValue, capturing it if we match. inline bind_ty m_VPValue(VPValue *&V) { return V; } /// Match a VPInstruction, capturing if we match. inline bind_ty m_VPInstruction(VPInstruction *&V) { return V; } template struct Recipe_match { Ops_t Ops; Recipe_match() : Ops() { static_assert(std::tuple_size::value == 0 && "constructor can only be used with zero operands"); } Recipe_match(Ops_t Ops) : Ops(Ops) {} template Recipe_match(A_t A, B_t B) : Ops({A, B}) { static_assert(std::tuple_size::value == 2 && "constructor can only be used for binary matcher"); } bool match(const VPValue *V) const { auto *DefR = V->getDefiningRecipe(); return DefR && match(DefR); } bool match(const VPSingleDefRecipe *R) const { return match(static_cast(R)); } bool match(const VPRecipeBase *R) const { if (std::tuple_size::value == 0) { assert(Opcode == VPInstruction::BuildVector && "can only match BuildVector with empty ops"); auto *VPI = dyn_cast(R); return VPI && VPI->getOpcode() == VPInstruction::BuildVector; } if ((!matchRecipeAndOpcode(R) && ...)) return false; assert(R->getNumOperands() == std::tuple_size::value && "recipe with matched opcode does not have the expected number of " "operands"); auto IdxSeq = std::make_index_sequence::value>(); if (all_of_tuple_elements(IdxSeq, [R](auto Op, unsigned Idx) { return Op.match(R->getOperand(Idx)); })) return true; return Commutative && all_of_tuple_elements(IdxSeq, [R](auto Op, unsigned Idx) { return Op.match(R->getOperand(R->getNumOperands() - Idx - 1)); }); } private: template static bool matchRecipeAndOpcode(const VPRecipeBase *R) { auto *DefR = dyn_cast(R); // Check for recipes that do not have opcodes. if constexpr (std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value) return DefR; else return DefR && DefR->getOpcode() == Opcode; } /// Helper to check if predicate \p P holds on all tuple elements in Ops using /// the provided index sequence. template bool all_of_tuple_elements(std::index_sequence, Fn P) const { return (P(std::get(Ops), Is) && ...); } }; template using ZeroOpRecipe_match = Recipe_match, Opcode, false, RecipeTys...>; template using UnaryRecipe_match = Recipe_match, Opcode, false, RecipeTys...>; template using UnaryVPInstruction_match = UnaryRecipe_match; template using ZeroOpVPInstruction_match = ZeroOpRecipe_match; template using AllUnaryRecipe_match = UnaryRecipe_match; template using BinaryRecipe_match = Recipe_match, Opcode, Commutative, RecipeTys...>; template using BinaryVPInstruction_match = BinaryRecipe_match; template using TernaryRecipe_match = Recipe_match, Opcode, Commutative, RecipeTys...>; template using TernaryVPInstruction_match = TernaryRecipe_match; template using AllBinaryRecipe_match = BinaryRecipe_match; /// BuildVector is matches only its opcode, w/o matching its operands as the /// number of operands is not fixed. inline ZeroOpVPInstruction_match m_BuildVector() { return ZeroOpVPInstruction_match(); } template inline UnaryVPInstruction_match m_VPInstruction(const Op0_t &Op0) { return UnaryVPInstruction_match(Op0); } template inline BinaryVPInstruction_match m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1) { return BinaryVPInstruction_match(Op0, Op1); } template inline TernaryVPInstruction_match m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { return TernaryVPInstruction_match( {Op0, Op1, Op2}); } template using Recipe4Op_match = Recipe_match, Opcode, Commutative, RecipeTys...>; template using VPInstruction4Op_match = Recipe4Op_match; template inline VPInstruction4Op_match m_VPInstruction(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2, const Op3_t &Op3) { return VPInstruction4Op_match( {Op0, Op1, Op2, Op3}); } template inline UnaryVPInstruction_match m_Freeze(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline UnaryVPInstruction_match m_BranchOnCond(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline UnaryVPInstruction_match m_Broadcast(const Op0_t &Op0) { return m_VPInstruction(Op0); } template inline BinaryVPInstruction_match m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1) { return m_VPInstruction(Op0, Op1); } template inline BinaryVPInstruction_match m_BranchOnCount(const Op0_t &Op0, const Op1_t &Op1) { return m_VPInstruction(Op0, Op1); } template inline AllUnaryRecipe_match m_Unary(const Op0_t &Op0) { return AllUnaryRecipe_match(Op0); } template inline AllUnaryRecipe_match m_Trunc(const Op0_t &Op0) { return m_Unary(Op0); } template inline AllUnaryRecipe_match m_ZExt(const Op0_t &Op0) { return m_Unary(Op0); } template inline AllUnaryRecipe_match m_SExt(const Op0_t &Op0) { return m_Unary(Op0); } template inline match_combine_or, AllUnaryRecipe_match> m_ZExtOrSExt(const Op0_t &Op0) { return m_CombineOr(m_ZExt(Op0), m_SExt(Op0)); } template inline AllBinaryRecipe_match m_Binary(const Op0_t &Op0, const Op1_t &Op1) { return AllBinaryRecipe_match(Op0, Op1); } template inline AllBinaryRecipe_match m_c_Binary(const Op0_t &Op0, const Op1_t &Op1) { return AllBinaryRecipe_match(Op0, Op1); } template inline AllBinaryRecipe_match m_Mul(const Op0_t &Op0, const Op1_t &Op1) { return m_Binary(Op0, Op1); } template inline AllBinaryRecipe_match m_c_Mul(const Op0_t &Op0, const Op1_t &Op1) { return m_Binary(Op0, Op1); } /// Match a binary OR operation. Note that while conceptually the operands can /// be matched commutatively, \p Commutative defaults to false in line with the /// IR-based pattern matching infrastructure. Use m_c_BinaryOr for a commutative /// version of the matcher. template inline AllBinaryRecipe_match m_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { return m_Binary(Op0, Op1); } template inline AllBinaryRecipe_match m_c_BinaryOr(const Op0_t &Op0, const Op1_t &Op1) { return m_BinaryOr(Op0, Op1); } template using GEPLikeRecipe_match = BinaryRecipe_match; template inline GEPLikeRecipe_match m_GetElementPtr(const Op0_t &Op0, const Op1_t &Op1) { return GEPLikeRecipe_match(Op0, Op1); } template using AllTernaryRecipe_match = Recipe_match, Opcode, false, VPReplicateRecipe, VPInstruction, VPWidenSelectRecipe>; template inline AllTernaryRecipe_match m_Select(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { return AllTernaryRecipe_match( {Op0, Op1, Op2}); } template inline match_combine_or, AllBinaryRecipe_match, Op0_t, Instruction::Xor, true>> m_Not(const Op0_t &Op0) { return m_CombineOr(m_VPInstruction(Op0), m_c_Binary(m_AllOnes(), Op0)); } template inline match_combine_or< BinaryVPInstruction_match, AllTernaryRecipe_match, Instruction::Select>> m_LogicalAnd(const Op0_t &Op0, const Op1_t &Op1) { return m_CombineOr( m_VPInstruction(Op0, Op1), m_Select(Op0, Op1, m_False())); } template inline AllTernaryRecipe_match, Op1_t, Instruction::Select> m_LogicalOr(const Op0_t &Op0, const Op1_t &Op1) { return m_Select(Op0, m_True(), Op1); } template using VPScalarIVSteps_match = TernaryRecipe_match; template inline VPScalarIVSteps_match m_ScalarIVSteps(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { return VPScalarIVSteps_match({Op0, Op1, Op2}); } template using VPDerivedIV_match = Recipe_match, 0, false, VPDerivedIVRecipe>; template inline VPDerivedIV_match m_DerivedIV(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) { return VPDerivedIV_match({Op0, Op1, Op2}); } /// Match a call argument at a given argument index. template struct Argument_match { /// Call argument index to match. unsigned OpI; Opnd_t Val; Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} template bool match(OpTy *V) const { if (const auto *R = dyn_cast(V)) return Val.match(R->getOperand(OpI)); if (const auto *R = dyn_cast(V)) return Val.match(R->getOperand(OpI)); if (const auto *R = dyn_cast(V)) if (isa(R->getUnderlyingInstr())) return Val.match(R->getOperand(OpI + 1)); return false; } }; /// Match a call argument. template inline Argument_match m_Argument(const Opnd_t &Op) { return Argument_match(OpI, Op); } /// Intrinsic matchers. struct IntrinsicID_match { unsigned ID; IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} template bool match(OpTy *V) const { if (const auto *R = dyn_cast(V)) return R->getVectorIntrinsicID() == ID; if (const auto *R = dyn_cast(V)) return R->getCalledScalarFunction()->getIntrinsicID() == ID; if (const auto *R = dyn_cast(V)) if (const auto *CI = dyn_cast(R->getUnderlyingInstr())) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; } }; /// Intrinsic matches are combinations of ID matchers, and argument /// matchers. Higher arity matcher are defined recursively in terms of and-ing /// them with lower arity matchers. Here's some convenient typedefs for up to /// several arguments, and more can be added as needed template struct m_Intrinsic_Ty; template struct m_Intrinsic_Ty { using Ty = match_combine_and>; }; template struct m_Intrinsic_Ty { using Ty = match_combine_and::Ty, Argument_match>; }; template struct m_Intrinsic_Ty { using Ty = match_combine_and::Ty, Argument_match>; }; template struct m_Intrinsic_Ty { using Ty = match_combine_and::Ty, Argument_match>; }; /// Match intrinsic calls like this: /// m_Intrinsic(m_VPValue(X), ...) template inline IntrinsicID_match m_Intrinsic() { return IntrinsicID_match(IntrID); } template inline typename m_Intrinsic_Ty::Ty m_Intrinsic(const T0 &Op0) { return m_CombineAnd(m_Intrinsic(), m_Argument<0>(Op0)); } template inline typename m_Intrinsic_Ty::Ty m_Intrinsic(const T0 &Op0, const T1 &Op1) { return m_CombineAnd(m_Intrinsic(Op0), m_Argument<1>(Op1)); } template inline typename m_Intrinsic_Ty::Ty m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2) { return m_CombineAnd(m_Intrinsic(Op0, Op1), m_Argument<2>(Op2)); } template inline typename m_Intrinsic_Ty::Ty m_Intrinsic(const T0 &Op0, const T1 &Op1, const T2 &Op2, const T3 &Op3) { return m_CombineAnd(m_Intrinsic(Op0, Op1, Op2), m_Argument<3>(Op3)); } } // namespace VPlanPatternMatch } // namespace llvm #endif