From 010f1083822397e5ec977c79ff08b8e93b18ec0f Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Fri, 30 Aug 2013 14:35:35 +0000 Subject: InstCombine: Check for zero shift amounts before subtracting one causing integer overflow. PR17026. Also avoid undefined shifts and shift amounts larger than 64 bits (those are always undef because we can't represent integer types that large). llvm-svn: 189672 --- .../InstCombine/InstCombineSimplifyDemanded.cpp | 25 +++++++++++++--------- 1 file changed, 15 insertions(+), 10 deletions(-) (limited to 'llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp') diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index a7bfe09..a2492d8 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -845,21 +845,26 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr, Instruction *Shl, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne) { - unsigned ShlAmt = cast(Shl->getOperand(1))->getZExtValue(); - unsigned ShrAmt = cast(Shr->getOperand(1))->getZExtValue(); + const APInt &ShlOp1 = cast(Shl->getOperand(1))->getValue(); + const APInt &ShrOp1 = cast(Shr->getOperand(1))->getValue(); + if (!ShlOp1 || !ShrOp1) + return 0; // Noop. + + Value *VarX = Shr->getOperand(0); + Type *Ty = VarX->getType(); + unsigned BitWidth = Ty->getIntegerBitWidth(); + if (ShlOp1.uge(BitWidth) || ShrOp1.uge(BitWidth)) + return 0; // Undef. + + unsigned ShlAmt = ShlOp1.getZExtValue(); + unsigned ShrAmt = ShrOp1.getZExtValue(); KnownOne.clearAllBits(); KnownZero = APInt::getBitsSet(KnownZero.getBitWidth(), 0, ShlAmt-1); KnownZero &= DemandedMask; - if (ShlAmt == 0 || ShrAmt == 0) - return 0; - - Value *VarX = Shr->getOperand(0); - Type *Ty = VarX->getType(); - - APInt BitMask1(APInt::getAllOnesValue(Ty->getIntegerBitWidth())); - APInt BitMask2(APInt::getAllOnesValue(Ty->getIntegerBitWidth())); + APInt BitMask1(APInt::getAllOnesValue(BitWidth)); + APInt BitMask2(APInt::getAllOnesValue(BitWidth)); bool isLshr = (Shr->getOpcode() == Instruction::LShr); BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) : -- cgit v1.1