diff options
Diffstat (limited to 'llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h')
-rw-r--r-- | llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h index 164b46b..871028d 100644 --- a/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h +++ b/llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h @@ -182,6 +182,12 @@ m_scev_PtrToInt(const Op0_t &Op0) { return SCEVUnaryExpr_match<SCEVPtrToIntExpr, Op0_t>(Op0); } +template <typename Op0_t> +inline SCEVUnaryExpr_match<SCEVTruncateExpr, Op0_t> +m_scev_Trunc(const Op0_t &Op0) { + return m_scev_Unary<SCEVTruncateExpr>(Op0); +} + /// Match a binary SCEV. template <typename SCEVTy, typename Op0_t, typename Op1_t, SCEV::NoWrapFlags WrapFlags = SCEV::FlagAnyWrap, @@ -246,6 +252,80 @@ m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) { return m_scev_Binary<SCEVUDivExpr>(Op0, Op1); } +/// Match unsigned remainder pattern. +/// Matches patterns generated by getURemExpr. +template <typename Op0_t, typename Op1_t> struct SCEVURem_match { + Op0_t Op0; + Op1_t Op1; + ScalarEvolution &SE; + + SCEVURem_match(Op0_t Op0, Op1_t Op1, ScalarEvolution &SE) + : Op0(Op0), Op1(Op1), SE(SE) {} + + bool match(const SCEV *Expr) const { + if (Expr->getType()->isPointerTy()) + return false; + + // Try to match 'zext (trunc A to iB) to iY', which is used + // for URem with constant power-of-2 second operands. Make sure the size of + // the operand A matches the size of the whole expressions. + const SCEV *LHS; + if (SCEVPatternMatch::match(Expr, m_scev_ZExt(m_scev_Trunc(m_SCEV(LHS))))) { + Type *TruncTy = cast<SCEVZeroExtendExpr>(Expr)->getOperand()->getType(); + // Bail out if the type of the LHS is larger than the type of the + // expression for now. + if (SE.getTypeSizeInBits(LHS->getType()) > + SE.getTypeSizeInBits(Expr->getType())) + return false; + if (LHS->getType() != Expr->getType()) + LHS = SE.getZeroExtendExpr(LHS, Expr->getType()); + const SCEV *RHS = + SE.getConstant(APInt(SE.getTypeSizeInBits(Expr->getType()), 1) + << SE.getTypeSizeInBits(TruncTy)); + return Op0.match(LHS) && Op1.match(RHS); + } + const auto *Add = dyn_cast<SCEVAddExpr>(Expr); + if (Add == nullptr || Add->getNumOperands() != 2) + return false; + + const SCEV *A = Add->getOperand(1); + const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0)); + + if (Mul == nullptr) + return false; + + const auto MatchURemWithDivisor = [&](const SCEV *B) { + // (SomeExpr + (-(SomeExpr / B) * B)). + if (Expr == SE.getURemExpr(A, B)) + return Op0.match(A) && Op1.match(B); + return false; + }; + + // (SomeExpr + (-1 * (SomeExpr / B) * B)). + if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0))) + return MatchURemWithDivisor(Mul->getOperand(1)) || + MatchURemWithDivisor(Mul->getOperand(2)); + + // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)). + if (Mul->getNumOperands() == 2) + return MatchURemWithDivisor(Mul->getOperand(1)) || + MatchURemWithDivisor(Mul->getOperand(0)) || + MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(1))) || + MatchURemWithDivisor(SE.getNegativeSCEV(Mul->getOperand(0))); + return false; + } +}; + +/// Match the mathematical pattern A - (A / B) * B, where A and B can be +/// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used +/// for URem with constant power-of-2 second operands. It's not always easy, as +/// A and B can be folded (imagine A is X / 2, and B is 4, A / B becomes X / 8). +template <typename Op0_t, typename Op1_t> +inline SCEVURem_match<Op0_t, Op1_t> m_scev_URem(Op0_t LHS, Op1_t RHS, + ScalarEvolution &SE) { + return SCEVURem_match<Op0_t, Op1_t>(LHS, RHS, SE); +} + inline class_match<const Loop> m_Loop() { return class_match<const Loop>(); } /// Match an affine SCEVAddRecExpr. |