diff options
Diffstat (limited to 'llvm/lib/Target/AArch64/AArch64ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/AArch64/AArch64ISelLowering.cpp | 40 |
1 files changed, 36 insertions, 4 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index be2f2e4..662d84b 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1561,6 +1561,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_ADD, VT, Custom); setOperationAction(ISD::VECREDUCE_AND, VT, Custom); setOperationAction(ISD::VECREDUCE_OR, VT, Custom); + setOperationAction(ISD::VECREDUCE_MUL, VT, Custom); setOperationAction(ISD::VECREDUCE_XOR, VT, Custom); setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom); @@ -1717,6 +1718,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom); setOperationAction(ISD::VECREDUCE_FMAXIMUM, VT, Custom); setOperationAction(ISD::VECREDUCE_FMINIMUM, VT, Custom); + setOperationAction(ISD::VECREDUCE_FMUL, VT, Custom); setOperationAction(ISD::VECTOR_SPLICE, VT, Custom); setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom); setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom); @@ -7775,6 +7777,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, case ISD::VECREDUCE_FMAXIMUM: case ISD::VECREDUCE_FMINIMUM: return LowerVECREDUCE(Op, DAG); + case ISD::VECREDUCE_MUL: + case ISD::VECREDUCE_FMUL: + return LowerVECREDUCE_MUL(Op, DAG); case ISD::ATOMIC_LOAD_AND: return LowerATOMIC_LOAD_AND(Op, DAG); case ISD::DYNAMIC_STACKALLOC: @@ -16794,6 +16799,33 @@ SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op, } } +SDValue AArch64TargetLowering::LowerVECREDUCE_MUL(SDValue Op, + SelectionDAG &DAG) const { + SDLoc DL(Op); + SDValue Src = Op.getOperand(0); + EVT SrcVT = Src.getValueType(); + assert(SrcVT.isScalableVector() && "Unexpected operand type!"); + + SDVTList SrcVTs = DAG.getVTList(SrcVT, SrcVT); + unsigned BaseOpc = ISD::getVecReduceBaseOpcode(Op.getOpcode()); + SDValue Identity = DAG.getNeutralElement(BaseOpc, DL, SrcVT, Op->getFlags()); + + // Whilst we don't know the size of the vector we do know the maximum size so + // can perform a tree reduction with an identity vector, which means once we + // arrive at the result the remaining stages (when the vector is smaller than + // the maximum) have no affect. + + unsigned Segments = AArch64::SVEMaxBitsPerVector / AArch64::SVEBitsPerBlock; + unsigned Stages = llvm::Log2_32(Segments * SrcVT.getVectorMinNumElements()); + + for (unsigned I = 0; I < Stages; ++I) { + Src = DAG.getNode(ISD::VECTOR_DEINTERLEAVE, DL, SrcVTs, Src, Identity); + Src = DAG.getNode(BaseOpc, DL, SrcVT, Src.getValue(0), Src.getValue(1)); + } + + return DAG.getExtractVectorElt(DL, Op.getValueType(), Src, 0); +} + SDValue AArch64TargetLowering::LowerATOMIC_LOAD_AND(SDValue Op, SelectionDAG &DAG) const { auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>(); @@ -18144,8 +18176,8 @@ bool AArch64TargetLowering::lowerInterleavedStore(Instruction *Store, bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad( Instruction *Load, Value *Mask, IntrinsicInst *DI) const { const unsigned Factor = getDeinterleaveIntrinsicFactor(DI->getIntrinsicID()); - if (Factor != 2 && Factor != 4) { - LLVM_DEBUG(dbgs() << "Matching ld2 and ld4 patterns failed\n"); + if (Factor != 2 && Factor != 3 && Factor != 4) { + LLVM_DEBUG(dbgs() << "Matching ld2, ld3 and ld4 patterns failed\n"); return false; } auto *LI = dyn_cast<LoadInst>(Load); @@ -18223,8 +18255,8 @@ bool AArch64TargetLowering::lowerInterleaveIntrinsicToStore( Instruction *Store, Value *Mask, ArrayRef<Value *> InterleavedValues) const { unsigned Factor = InterleavedValues.size(); - if (Factor != 2 && Factor != 4) { - LLVM_DEBUG(dbgs() << "Matching st2 and st4 patterns failed\n"); + if (Factor != 2 && Factor != 3 && Factor != 4) { + LLVM_DEBUG(dbgs() << "Matching st2, st3 and st4 patterns failed\n"); return false; } StoreInst *SI = dyn_cast<StoreInst>(Store); |