aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib
diff options
context:
space:
mode:
authorzGoldthorpe <Zach.Goldthorpe@amd.com>2025-10-08 06:58:11 -0600
committerGitHub <noreply@github.com>2025-10-08 06:58:11 -0600
commit7910ed22320c5f298c4645ffa9072238c95bc7d6 (patch)
treea7f69b8eac13762766ac477264c917d371fe316a /llvm/lib
parentf53b6249c24005d1a6208cd9e355595eb6519dc0 (diff)
downloadllvm-7910ed22320c5f298c4645ffa9072238c95bc7d6.zip
llvm-7910ed22320c5f298c4645ffa9072238c95bc7d6.tar.gz
llvm-7910ed22320c5f298c4645ffa9072238c95bc7d6.tar.bz2
[InstCombine] Canonicalise packed-integer-selecting shifts (#162147)
This patch resolves recent regressions related to [issue #92891](https://github.com/llvm/llvm-project/issues/92891). It specifically enables the following types of reductions. ```llvm define i16 @src(i32 %mask, i32 %upper, i32 range(i32 0, 65536) %lower) { %upper.shl = shl nuw i32 %upper, 16 %pack = or disjoint i32 %upper.shl, %lower %mask.bit = and i32 %mask, 16 %sel = lshr i32 %pack, %mask.bit %trunc = trunc i32 %sel to i16 ret i16 %trunc } ; => define i16 @tgt(i32 %mask, i32 %upper, i32 range(i32 0, 65536) %lower) { %mask.bit = and i32 %mask, 16 %mask.bit.z = icmp eq i32 %mask.bit, 0 %sel = select i1 %mask.bit.z, i32 %lower, i32 %upper %trunc = trunc i32 %sel to i16 ret i16 %trunc } ``` Alive2 proofs: [gJ9MpP](https://alive2.llvm.org/ce/z/gJ9MpP)
Diffstat (limited to 'llvm/lib')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp60
1 files changed, 58 insertions, 2 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index aa030294..127a506 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -60,6 +60,58 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
return true;
}
+/// Let N = 2 * M.
+/// Given an N-bit integer representing a pack of two M-bit integers,
+/// we can select one of the packed integers by right-shifting by either
+/// zero or M (which is the most straightforward to check if M is a power
+/// of 2), and then isolating the lower M bits. In this case, we can
+/// represent the shift as a select on whether the shr amount is nonzero.
+static Value *simplifyShiftSelectingPackedElement(Instruction *I,
+ const APInt &DemandedMask,
+ InstCombinerImpl &IC,
+ unsigned Depth) {
+ assert(I->getOpcode() == Instruction::LShr &&
+ "Only lshr instruction supported");
+
+ uint64_t ShlAmt;
+ Value *Upper, *Lower;
+ if (!match(I->getOperand(0),
+ m_OneUse(m_c_DisjointOr(
+ m_OneUse(m_Shl(m_Value(Upper), m_ConstantInt(ShlAmt))),
+ m_Value(Lower)))))
+ return nullptr;
+
+ if (!isPowerOf2_64(ShlAmt))
+ return nullptr;
+
+ const uint64_t DemandedBitWidth = DemandedMask.getActiveBits();
+ if (DemandedBitWidth > ShlAmt)
+ return nullptr;
+
+ // Check that upper demanded bits are not lost from lshift.
+ if (Upper->getType()->getScalarSizeInBits() < ShlAmt + DemandedBitWidth)
+ return nullptr;
+
+ KnownBits KnownLowerBits = IC.computeKnownBits(Lower, I, Depth);
+ if (!KnownLowerBits.getMaxValue().isIntN(ShlAmt))
+ return nullptr;
+
+ Value *ShrAmt = I->getOperand(1);
+ KnownBits KnownShrBits = IC.computeKnownBits(ShrAmt, I, Depth);
+
+ // Verify that ShrAmt is either exactly ShlAmt (which is a power of 2) or
+ // zero.
+ if (~KnownShrBits.Zero != ShlAmt)
+ return nullptr;
+
+ Value *ShrAmtZ =
+ IC.Builder.CreateICmpEQ(ShrAmt, Constant::getNullValue(ShrAmt->getType()),
+ ShrAmt->getName() + ".z");
+ Value *Select = IC.Builder.CreateSelect(ShrAmtZ, Lower, Upper);
+ Select->takeName(I);
+ return Select;
+}
+
/// Returns the bitwidth of the given scalar or pointer type. For vector types,
/// returns the element type's bitwidth.
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) {
@@ -798,9 +850,13 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Instruction *I,
Known >>= ShiftAmt;
if (ShiftAmt)
Known.Zero.setHighBits(ShiftAmt); // high bits known zero.
- } else {
- llvm::computeKnownBits(I, Known, Q, Depth);
+ break;
}
+ if (Value *V =
+ simplifyShiftSelectingPackedElement(I, DemandedMask, *this, Depth))
+ return V;
+
+ llvm::computeKnownBits(I, Known, Q, Depth);
break;
}
case Instruction::AShr: {