aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp60
1 files changed, 58 insertions, 2 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index aa030294..127a506 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -60,6 +60,58 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
return true;
}
+/// Let N = 2 * M.
+/// Given an N-bit integer representing a pack of two M-bit integers,
+/// we can select one of the packed integers by right-shifting by either
+/// zero or M (which is the most straightforward to check if M is a power
+/// of 2), and then isolating the lower M bits. In this case, we can
+/// represent the shift as a select on whether the shr amount is nonzero.
+static Value *simplifyShiftSelectingPackedElement(Instruction *I,
+ const APInt &DemandedMask,
+ InstCombinerImpl &IC,
+ unsigned Depth) {
+ assert(I->getOpcode() == Instruction::LShr &&
+ "Only lshr instruction supported");
+
+ uint64_t ShlAmt;
+ Value *Upper, *Lower;
+ if (!match(I->getOperand(0),
+ m_OneUse(m_c_DisjointOr(
+ m_OneUse(m_Shl(m_Value(Upper), m_ConstantInt(ShlAmt))),
+ m_Value(Lower)))))
+ return nullptr;
+
+ if (!isPowerOf2_64(ShlAmt))
+ return nullptr;
+
+ const uint64_t DemandedBitWidth = DemandedMask.getActiveBits();
+ if (DemandedBitWidth > ShlAmt)
+ return nullptr;
+
+ // Check that upper demanded bits are not lost from lshift.
+ if (Upper->getType()->getScalarSizeInBits() < ShlAmt + DemandedBitWidth)
+ return nullptr;
+
+ KnownBits KnownLowerBits = IC.computeKnownBits(Lower, I, Depth);
+ if (!KnownLowerBits.getMaxValue().isIntN(ShlAmt))
+ return nullptr;
+
+ Value *ShrAmt = I->getOperand(1);
+ KnownBits KnownShrBits = IC.computeKnownBits(ShrAmt, I, Depth);
+
+ // Verify that ShrAmt is either exactly ShlAmt (which is a power of 2) or
+ // zero.
+ if (~KnownShrBits.Zero != ShlAmt)
+ return nullptr;
+
+ Value *ShrAmtZ =
+ IC.Builder.CreateICmpEQ(ShrAmt, Constant::getNullValue(ShrAmt->getType()),
+ ShrAmt->getName() + ".z");
+ Value *Select = IC.Builder.CreateSelect(ShrAmtZ, Lower, Upper);
+ Select->takeName(I);
+ return Select;
+}
+
/// Returns the bitwidth of the given scalar or pointer type. For vector types,
/// returns the element type's bitwidth.
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
@@ -798,9 +850,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I,
Known >>= ShiftAmt;
if (ShiftAmt)
Known.Zero.setHighBits(ShiftAmt); // high bits known zero.
- } else {
- llvm::computeKnownBits(I, Known, Q, Depth);
+ break;
}
+ if (Value *V =
+ simplifyShiftSelectingPackedElement(I, DemandedMask, *this, Depth))
+ return V;
+
+ llvm::computeKnownBits(I, Known, Q, Depth);
break;
}
case Instruction::AShr: {