diff options
author | Simon Pilgrim <llvm-dev@redking.me.uk> | 2024-09-28 17:52:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-28 17:52:10 +0100 |
commit | 795c24c6fb4f9635c912f4084fa9339ea068c3d5 (patch) | |
tree | bfedb58f6a139c882d729728e127a40b15ca7ae3 /llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | |
parent | 5d734fa4c8f358299a4dfd2a7f9315a226b94a4a (diff) | |
download | llvm-795c24c6fb4f9635c912f4084fa9339ea068c3d5.zip llvm-795c24c6fb4f9635c912f4084fa9339ea068c3d5.tar.gz llvm-795c24c6fb4f9635c912f4084fa9339ea068c3d5.tar.bz2 |
[InstCombine] foldVecExtTruncToExtElt - extend to handle trunc(lshr(extractelement(x,c1),c2)) -> extractelement(bitcast(x),c3) patterns. (#109689)
This patch moves the existing trunc+extractlement -> extractelement+bitcast fold into a foldVecExtTruncToExtElt helper and extends the helper to handle trunc+lshr+extractelement cases as well.
Fixes #107404
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp | 97 |
1 files changed, 67 insertions, 30 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index ea51d77..9934c06 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -436,6 +436,71 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt)); } +/// Whenever an element is extracted from a vector, optionally shifted down, and +/// then truncated, canonicalize by converting it to a bitcast followed by an +/// extractelement. +/// +/// Examples (little endian): +/// trunc (extractelement <4 x i64> %X, 0) to i32 +/// ---> +/// extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0 +/// +/// trunc (lshr (extractelement <4 x i32> %X, 0), 8) to i8 +/// ---> +/// extractelement <16 x i8> (bitcast <4 x i32> %X to <16 x i8>), i32 1 +static Instruction *foldVecExtTruncToExtElt(TruncInst &Trunc, + InstCombinerImpl &IC) { + Value *Src = Trunc.getOperand(0); + Type *SrcType = Src->getType(); + Type *DstType = Trunc.getType(); + + // Only attempt this if we have simple aliasing of the vector elements. + // A badly fit destination size would result in an invalid cast. + unsigned SrcBits = SrcType->getScalarSizeInBits(); + unsigned DstBits = DstType->getScalarSizeInBits(); + unsigned TruncRatio = SrcBits / DstBits; + if ((SrcBits % DstBits) != 0) + return nullptr; + + Value *VecOp; + ConstantInt *Cst; + const APInt *ShiftAmount = nullptr; + if (!match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)))) && + !match(Src, + m_OneUse(m_LShr(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst)), + m_APInt(ShiftAmount))))) + return nullptr; + + auto *VecOpTy = cast<VectorType>(VecOp->getType()); + auto VecElts = VecOpTy->getElementCount(); + + uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio; + uint64_t VecOpIdx = Cst->getZExtValue(); + uint64_t NewIdx = IC.getDataLayout().isBigEndian() + ? (VecOpIdx + 1) * TruncRatio - 1 + : VecOpIdx * TruncRatio; + + // Adjust index by the whole number of truncated elements. + if (ShiftAmount) { + // Check shift amount is in range and shifts a whole number of truncated + // elements. + if (ShiftAmount->uge(SrcBits) || ShiftAmount->urem(DstBits) != 0) + return nullptr; + + uint64_t IdxOfs = ShiftAmount->udiv(DstBits).getZExtValue(); + NewIdx = IC.getDataLayout().isBigEndian() ? (NewIdx - IdxOfs) + : (NewIdx + IdxOfs); + } + + assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() && + NewIdx <= std::numeric_limits<uint32_t>::max() && "overflow 32-bits"); + + auto *BitCastTo = + VectorType::get(DstType, BitCastNumElts, VecElts.isScalable()); + Value *BitCast = IC.Builder.CreateBitCast(VecOp, BitCastTo); + return ExtractElementInst::Create(BitCast, IC.Builder.getInt32(NewIdx)); +} + /// Funnel/Rotate left/right may occur in a wider type than necessary because of /// type promotion rules. Try to narrow the inputs and convert to funnel shift. Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) { @@ -848,36 +913,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { if (Instruction *I = foldVecTruncToExtElt(Trunc, *this)) return I; - // Whenever an element is extracted from a vector, and then truncated, - // canonicalize by converting it to a bitcast followed by an - // extractelement. - // - // Example (little endian): - // trunc (extractelement <4 x i64> %X, 0) to i32 - // ---> - // extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0 - Value *VecOp; - ConstantInt *Cst; - if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) { - auto *VecOpTy = cast<VectorType>(VecOp->getType()); - auto VecElts = VecOpTy->getElementCount(); - - // A badly fit destination size would result in an invalid cast. - if (SrcWidth % DestWidth == 0) { - uint64_t TruncRatio = SrcWidth / DestWidth; - uint64_t BitCastNumElts = VecElts.getKnownMinValue() * TruncRatio; - uint64_t VecOpIdx = Cst->getZExtValue(); - uint64_t NewIdx = DL.isBigEndian() ? (VecOpIdx + 1) * TruncRatio - 1 - : VecOpIdx * TruncRatio; - assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() && - "overflow 32-bits"); - - auto *BitCastTo = - VectorType::get(DestTy, BitCastNumElts, VecElts.isScalable()); - Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo); - return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx)); - } - } + if (Instruction *I = foldVecExtTruncToExtElt(Trunc, *this)) + return I; // trunc (ctlz_i32(zext(A), B) --> add(ctlz_i16(A, B), C) if (match(Src, m_OneUse(m_Intrinsic<Intrinsic::ctlz>(m_ZExt(m_Value(A)), |