aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
diff options
context:
space:
mode:
authorSimon Pilgrim <llvm-dev@redking.me.uk>2024-09-28 17:52:10 +0100
committerGitHub <noreply@github.com>2024-09-28 17:52:10 +0100
commit795c24c6fb4f9635c912f4084fa9339ea068c3d5 (patch)
treebfedb58f6a139c882d729728e127a40b15ca7ae3 /llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
parent5d734fa4c8f358299a4dfd2a7f9315a226b94a4a (diff)
downloadllvm-795c24c6fb4f9635c912f4084fa9339ea068c3d5.zip
llvm-795c24c6fb4f9635c912f4084fa9339ea068c3d5.tar.gz
llvm-795c24c6fb4f9635c912f4084fa9339ea068c3d5.tar.bz2
[InstCombine] foldVecExtTruncToExtElt - extend to handle trunc(lshr(extractelement(x,c1),c2)) -> extractelement(bitcast(x),c3) patterns. (#109689)
This patch moves the existing trunc+extractlement -> extractelement+bitcast fold into a foldVecExtTruncToExtElt helper and extends the helper to handle trunc+lshr+extractelement cases as well. Fixes #107404
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp97
1 files changed, 67 insertions, 30 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index ea51d77..9934c06 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -436,6 +436,71 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc,
return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
}
+/// Whenever an element is extracted from a vector, optionally shifted down, and
+/// then truncated, canonicalize by converting it to a bitcast followed by an
+/// extractelement.
+///
+/// Examples (little endian):
+/// trunc (extractelement <4 x i64> %X, 0) to i32
+/// --->
+/// extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
+///
+/// trunc (lshr (extractelement <4 x i32> %X, 0), 8) to i8
+/// --->
+/// extractelement <16 x i8> (bitcast <4 x i32> %X to <16 x i8>), i32 1
+static Instruction *foldVecExtTruncToExtElt(TruncInst &Trunc,
+ InstCombinerImpl &IC) {
+ Value *Src = Trunc.getOperand(0);
+ Type *SrcType = Src->getType();
+ Type *DstType = Trunc.getType();
+
+ // Only attempt this if we have simple aliasing of the vector elements.
+ // A badly fit destination size would result in an invalid cast.
+ unsigned SrcBits = SrcType->getScalarSizeInBits();
+ unsigned DstBits = DstType->getScalarSizeInBits();
+ unsigned TruncRatio = SrcBits / DstBits;
+ if ((SrcBits % DstBits) != 0)
+ return nullptr;
+
+ Value *VecOp;
+ ConstantInt *Cst;
+ const APInt *ShiftAmount = nullptr;
+ if (!match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)))) &&
+ !match(Src,
+ m_OneUse(m_LShr(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)),
+ m_APInt(ShiftAmount)))))
+ return nullptr;
+
+ auto *VecOpTy = cast<VectorType>(VecOp->getType());
+ auto VecElts = VecOpTy->getElementCount();
+
+ uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
+ uint64_t VecOpIdx = Cst->getZExtValue();
+ uint64_t NewIdx = IC.getDataLayout().isBigEndian()
+ ? (VecOpIdx + 1) * TruncRatio - 1
+ : VecOpIdx * TruncRatio;
+
+ // Adjust index by the whole number of truncated elements.
+ if (ShiftAmount) {
+ // Check shift amount is in range and shifts a whole number of truncated
+ // elements.
+ if (ShiftAmount->uge(SrcBits) || ShiftAmount->urem(DstBits) != 0)
+ return nullptr;
+
+ uint64_t IdxOfs = ShiftAmount->udiv(DstBits).getZExtValue();
+ NewIdx = IC.getDataLayout().isBigEndian() ? (NewIdx - IdxOfs)
+ : (NewIdx + IdxOfs);
+ }
+
+ assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
+ NewIdx <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits");
+
+ auto *BitCastTo =
+ VectorType::get(DstType, BitCastNumElts, VecElts.isScalable());
+ Value *BitCast = IC.Builder.CreateBitCast(VecOp, BitCastTo);
+ return ExtractElementInst::Create(BitCast, IC.Builder.getInt32(NewIdx));
+}
+
/// Funnel/Rotate left/right may occur in a wider type than necessary because of
/// type promotion rules. Try to narrow the inputs and convert to funnel shift.
Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) {
@@ -848,36 +913,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
if (Instruction *I = foldVecTruncToExtElt(Trunc, *this))
return I;
- // Whenever an element is extracted from a vector, and then truncated,
- // canonicalize by converting it to a bitcast followed by an
- // extractelement.
- //
- // Example (little endian):
- // trunc (extractelement <4 x i64> %X, 0) to i32
- // --->
- // extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0
- Value *VecOp;
- ConstantInt *Cst;
- if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) {
- auto *VecOpTy = cast<VectorType>(VecOp->getType());
- auto VecElts = VecOpTy->getElementCount();
-
- // A badly fit destination size would result in an invalid cast.
- if (SrcWidth % DestWidth == 0) {
- uint64_t TruncRatio = SrcWidth / DestWidth;
- uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio;
- uint64_t VecOpIdx = Cst->getZExtValue();
- uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1
- : VecOpIdx * TruncRatio;
- assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() &&
- "overflow 32-bits");
-
- auto *BitCastTo =
- VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable());
- Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo);
- return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx));
- }
- }
+ if (Instruction *I = foldVecExtTruncToExtElt(Trunc, *this))
+ return I;
// trunc (ctlz_i32(zext(A), B) --> add(ctlz_i16(A, B), C)
if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ctlz>(m_ZExt(m_Value(A)),