diff options
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 6d7c279..f6c7cab 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -54490,25 +54490,6 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, if (Op0.getOpcode() == X86ISD::VBROADCAST) return DAG.getNode(Op0.getOpcode(), DL, VT, Op0.getOperand(0)); - // If this simple subvector or scalar/subvector broadcast_load is inserted - // into both halves, use a larger broadcast_load. Update other uses to use - // an extracted subvector. - if (ISD::isNormalLoad(Op0.getNode()) || - Op0.getOpcode() == X86ISD::VBROADCAST_LOAD || - Op0.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) { - auto *Mem = cast<MemSDNode>(Op0); - unsigned Opc = Op0.getOpcode() == X86ISD::VBROADCAST_LOAD - ? X86ISD::VBROADCAST_LOAD - : X86ISD::SUBV_BROADCAST_LOAD; - if (SDValue BcastLd = - getBROADCAST_LOAD(Opc, DL, VT, Mem->getMemoryVT(), Mem, 0, DAG)) { - SDValue BcastSrc = - extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()); - DAG.ReplaceAllUsesOfValueWith(Op0, BcastSrc); - return BcastLd; - } - } - // concat_vectors(movddup(x),movddup(x)) -> broadcast(x) if (Op0.getOpcode() == X86ISD::MOVDDUP && VT == MVT::v4f64 && (Subtarget.hasAVX2() || @@ -54995,6 +54976,28 @@ static SDValue combineConcatVectorOps(const SDLoc &DL, MVT VT, } } + // If this simple subvector or scalar/subvector broadcast_load is inserted + // into both halves, use a larger broadcast_load. Update other uses to use + // an extracted subvector. + if (IsSplat && + (VT.is256BitVector() || (VT.is512BitVector() && Subtarget.hasAVX512()))) { + if (ISD::isNormalLoad(Op0.getNode()) || + Op0.getOpcode() == X86ISD::VBROADCAST_LOAD || + Op0.getOpcode() == X86ISD::SUBV_BROADCAST_LOAD) { + auto *Mem = cast<MemSDNode>(Op0); + unsigned Opc = Op0.getOpcode() == X86ISD::VBROADCAST_LOAD + ? X86ISD::VBROADCAST_LOAD + : X86ISD::SUBV_BROADCAST_LOAD; + if (SDValue BcastLd = + getBROADCAST_LOAD(Opc, DL, VT, Mem->getMemoryVT(), Mem, 0, DAG)) { + SDValue BcastSrc = + extractSubVector(BcastLd, 0, DAG, DL, Op0.getValueSizeInBits()); + DAG.ReplaceAllUsesOfValueWith(Op0, BcastSrc); + return BcastLd; + } + } + } + // Attempt to fold target constant loads. if (all_of(Ops, [](SDValue Op) { return getTargetConstantFromNode(Op); })) { SmallVector<APInt> EltBits; |