diff options
-rw-r--r-- | llvm/include/llvm/Analysis/ValueTracking.h | 16 | ||||
-rw-r--r-- | llvm/lib/Analysis/ValueTracking.cpp | 73 | ||||
-rw-r--r-- | llvm/unittests/Analysis/ValueTrackingTest.cpp | 52 |
3 files changed, 111 insertions, 30 deletions
diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h index c804f55..15e23de 100644 --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -21,6 +21,7 @@ #include "llvm/IR/FMF.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/Support/Compiler.h" #include <cassert> @@ -965,6 +966,21 @@ LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, LLVM_ABI bool matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P, Value *&Start, Value *&Step); +/// Attempt to match a simple value-accumulating recurrence of the form: +/// %llvm.intrinsic.acc = phi Ty [%Init, %Entry], [%llvm.intrinsic, %backedge] +/// %llvm.intrinsic = call Ty @llvm.intrinsic(%OtherOp, %llvm.intrinsic.acc) +/// OR +/// %llvm.intrinsic.acc = phi Ty [%Init, %Entry], [%llvm.intrinsic, %backedge] +/// %llvm.intrinsic = call Ty @llvm.intrinsic(%llvm.intrinsic.acc, %OtherOp) +/// +/// The recurrence relation is of kind: +/// X_0 = %a (initial value), +/// X_i = call @llvm.binary.intrinsic(X_i-1, %b) +/// Where %b is not required to be loop-invariant. +LLVM_ABI bool matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I, + PHINode *&P, Value *&Init, + Value *&OtherOp); + /// Return true if RHS is known to be implied true by LHS. Return false if /// RHS is known to be implied false by LHS. Otherwise, return std::nullopt if /// no implication can be made. A & B must be i1 (boolean) values or a vector of diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp index 93c2221..e576f48 100644 --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -9070,46 +9070,43 @@ llvm::canConvertToMinOrMaxIntrinsic(ArrayRef<Value *> VL) { return {Intrinsic::not_intrinsic, false}; } -bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, - Value *&Start, Value *&Step) { +template <typename InstTy> +static bool matchTwoInputRecurrence(const PHINode *PN, InstTy *&Inst, + Value *&Init, Value *&OtherOp) { // Handle the case of a simple two-predecessor recurrence PHI. // There's a lot more that could theoretically be done here, but // this is sufficient to catch some interesting cases. // TODO: Expand list -- gep, uadd.sat etc. - if (P->getNumIncomingValues() != 2) + if (PN->getNumIncomingValues() != 2) return false; - for (unsigned i = 0; i != 2; ++i) { - Value *L = P->getIncomingValue(i); - Value *R = P->getIncomingValue(!i); - auto *LU = dyn_cast<BinaryOperator>(L); - if (!LU) - continue; - Value *LL = LU->getOperand(0); - Value *LR = LU->getOperand(1); - - // Find a recurrence. - if (LL == P) - L = LR; - else if (LR == P) - L = LL; - else - continue; // Check for recurrence with L and R flipped. - - // We have matched a recurrence of the form: - // %iv = [R, %entry], [%iv.next, %backedge] - // %iv.next = binop %iv, L - // OR - // %iv = [R, %entry], [%iv.next, %backedge] - // %iv.next = binop L, %iv - BO = LU; - Start = R; - Step = L; - return true; + for (unsigned I = 0; I != 2; ++I) { + if (auto *Operation = dyn_cast<InstTy>(PN->getIncomingValue(I))) { + Value *LHS = Operation->getOperand(0); + Value *RHS = Operation->getOperand(1); + if (LHS != PN && RHS != PN) + continue; + + Inst = Operation; + Init = PN->getIncomingValue(!I); + OtherOp = (LHS == PN) ? RHS : LHS; + return true; + } } return false; } +bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO, + Value *&Start, Value *&Step) { + // We try to match a recurrence of the form: + // %iv = [Start, %entry], [%iv.next, %backedge] + // %iv.next = binop %iv, Step + // Or: + // %iv = [Start, %entry], [%iv.next, %backedge] + // %iv.next = binop Step, %iv + return matchTwoInputRecurrence(P, BO, Start, Step); +} + bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P, Value *&Start, Value *&Step) { BinaryOperator *BO = nullptr; @@ -9119,6 +9116,22 @@ bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P, return P && matchSimpleRecurrence(P, BO, Start, Step) && BO == I; } +bool llvm::matchSimpleBinaryIntrinsicRecurrence(const IntrinsicInst *I, + PHINode *&P, Value *&Init, + Value *&OtherOp) { + // Binary intrinsics only supported for now. + if (I->arg_size() != 2 || I->getType() != I->getArgOperand(0)->getType() || + I->getType() != I->getArgOperand(1)->getType()) + return false; + + IntrinsicInst *II = nullptr; + P = dyn_cast<PHINode>(I->getArgOperand(0)); + if (!P) + P = dyn_cast<PHINode>(I->getArgOperand(1)); + + return P && matchTwoInputRecurrence(P, II, Init, OtherOp) && II == I; +} + /// Return true if "icmp Pred LHS RHS" is always true. static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS, const Value *RHS) { diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp index 6031898..dbe7228 100644 --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -1257,6 +1257,58 @@ TEST_F(ValueTrackingTest, computePtrAlignment) { EXPECT_EQ(getKnownAlignment(A, DL, CxtI3, &AC, &DT), Align(16)); } +TEST_F(ValueTrackingTest, MatchBinaryIntrinsicRecurrenceUMax) { + auto M = parseModule(R"( + define i8 @test(i8 %a, i8 %b) { + entry: + br label %loop + loop: + %iv = phi i8 [ %iv.next, %loop ], [ 0, %entry ] + %umax.acc = phi i8 [ %umax, %loop ], [ %a, %entry ] + %umax = call i8 @llvm.umax.i8(i8 %umax.acc, i8 %b) + %iv.next = add nuw i8 %iv, 1 + %cmp = icmp ult i8 %iv.next, 10 + br i1 %cmp, label %loop, label %exit + exit: + ret i8 %umax + } + )"); + + auto *F = M->getFunction("test"); + auto *II = &cast<IntrinsicInst>(findInstructionByName(F, "umax")); + auto *UMaxAcc = &cast<PHINode>(findInstructionByName(F, "umax.acc")); + PHINode *PN; + Value *Init, *OtherOp; + EXPECT_TRUE(matchSimpleBinaryIntrinsicRecurrence(II, PN, Init, OtherOp)); + EXPECT_EQ(UMaxAcc, PN); + EXPECT_EQ(F->getArg(0), Init); + EXPECT_EQ(F->getArg(1), OtherOp); +} + +TEST_F(ValueTrackingTest, MatchBinaryIntrinsicRecurrenceNegativeFSHR) { + auto M = parseModule(R"( + define i8 @test(i8 %a, i8 %b, i8 %c) { + entry: + br label %loop + loop: + %iv = phi i8 [ %iv.next, %loop ], [ 0, %entry ] + %fshr.acc = phi i8 [ %fshr, %loop ], [ %a, %entry ] + %fshr = call i8 @llvm.fshr.i8(i8 %fshr.acc, i8 %b, i8 %c) + %iv.next = add nuw i8 %iv, 1 + %cmp = icmp ult i8 %iv.next, 10 + br i1 %cmp, label %loop, label %exit + exit: + ret i8 %fshr + } + )"); + + auto *F = M->getFunction("test"); + auto *II = &cast<IntrinsicInst>(findInstructionByName(F, "fshr")); + PHINode *PN; + Value *Init, *OtherOp; + EXPECT_FALSE(matchSimpleBinaryIntrinsicRecurrence(II, PN, Init, OtherOp)); +} + TEST_F(ComputeKnownBitsTest, ComputeKnownBits) { parseAssembly( "define i32 @test(i32 %a, i32 %b) {\n" |