diff options
Diffstat (limited to 'llvm/lib/Target/X86/X86ISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/X86/X86ISelLowering.cpp | 74 |
1 files changed, 27 insertions, 47 deletions
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index a0b64ff..b05d7c7 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -29755,65 +29755,30 @@ static SDValue LowervXi8MulWithUNPCK(SDValue A, SDValue B, const SDLoc &dl, const X86Subtarget &Subtarget, SelectionDAG &DAG, SDValue *Low = nullptr) { - unsigned NumElts = VT.getVectorNumElements(); - // For vXi8 we will unpack the low and high half of each 128 bit lane to widen // to a vXi16 type. Do the multiplies, shift the results and pack the half // lane results back together. // We'll take different approaches for signed and unsigned. - // For unsigned we'll use punpcklbw/punpckhbw to put zero extend the bytes - // and use pmullw to calculate the full 16-bit product. + // For unsigned we'll use punpcklbw/punpckhbw to zero extend the bytes to + // words and use pmullw to calculate the full 16-bit product. // For signed we'll use punpcklbw/punpckbw to extend the bytes to words and // shift them left into the upper byte of each word. This allows us to use // pmulhw to calculate the full 16-bit product. This trick means we don't // need to sign extend the bytes to use pmullw. - - MVT ExVT = MVT::getVectorVT(MVT::i16, NumElts / 2); + MVT ExVT = MVT::getVectorVT(MVT::i16, VT.getVectorNumElements() / 2); SDValue Zero = DAG.getConstant(0, dl, VT); - SDValue ALo, AHi; + SDValue ALo, AHi, BLo, BHi; if (IsSigned) { ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, A)); - AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, A)); - } else { - ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Zero)); - AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Zero)); - } - - SDValue BLo, BHi; - if (ISD::isBuildVectorOfConstantSDNodes(B.getNode())) { - // If the RHS is a constant, manually unpackl/unpackh and extend. - SmallVector<SDValue, 16> LoOps, HiOps; - for (unsigned i = 0; i != NumElts; i += 16) { - for (unsigned j = 0; j != 8; ++j) { - SDValue LoOp = B.getOperand(i + j); - SDValue HiOp = B.getOperand(i + j + 8); - - if (IsSigned) { - LoOp = DAG.getAnyExtOrTrunc(LoOp, dl, MVT::i16); - HiOp = DAG.getAnyExtOrTrunc(HiOp, dl, MVT::i16); - LoOp = DAG.getNode(ISD::SHL, dl, MVT::i16, LoOp, - DAG.getConstant(8, dl, MVT::i16)); - HiOp = DAG.getNode(ISD::SHL, dl, MVT::i16, HiOp, - DAG.getConstant(8, dl, MVT::i16)); - } else { - LoOp = DAG.getZExtOrTrunc(LoOp, dl, MVT::i16); - HiOp = DAG.getZExtOrTrunc(HiOp, dl, MVT::i16); - } - - LoOps.push_back(LoOp); - HiOps.push_back(HiOp); - } - } - - BLo = DAG.getBuildVector(ExVT, dl, LoOps); - BHi = DAG.getBuildVector(ExVT, dl, HiOps); - } else if (IsSigned) { BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, Zero, B)); + AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, A)); BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, Zero, B)); } else { + ALo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, A, Zero)); BLo = DAG.getBitcast(ExVT, getUnpackl(DAG, dl, VT, B, Zero)); + AHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, A, Zero)); BHi = DAG.getBitcast(ExVT, getUnpackh(DAG, dl, VT, B, Zero)); } @@ -29826,7 +29791,7 @@ static SDValue LowervXi8MulWithUNPCK(SDValue A, SDValue B, const SDLoc &dl, if (Low) *Low = getPack(DAG, Subtarget, dl, VT, RLo, RHi); - return getPack(DAG, Subtarget, dl, VT, RLo, RHi, /*PackHiHalf*/ true); + return getPack(DAG, Subtarget, dl, VT, RLo, RHi, /*PackHiHalf=*/true); } static SDValue LowerMULH(SDValue Op, const X86Subtarget &Subtarget, @@ -44848,10 +44813,16 @@ bool X86TargetLowering::SimplifyDemandedBitsForTargetNode( } case X86ISD::PCMPGT: // icmp sgt(0, R) == ashr(R, BitWidth-1). - // iff we only need the sign bit then we can use R directly. - if (OriginalDemandedBits.isSignMask() && - ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) - return TLO.CombineTo(Op, Op.getOperand(1)); + if (ISD::isBuildVectorAllZeros(Op.getOperand(0).getNode())) { + // iff we only need the signbit then we can use R directly. + if (OriginalDemandedBits.isSignMask()) + return TLO.CombineTo(Op, Op.getOperand(1)); + // otherwise we just need R's signbit for the comparison. + APInt SignMask = APInt::getSignMask(BitWidth); + if (SimplifyDemandedBits(Op.getOperand(1), SignMask, OriginalDemandedElts, + Known, TLO, Depth + 1)) + return true; + } break; case X86ISD::MOVMSK: { SDValue Src = Op.getOperand(0); @@ -47761,6 +47732,15 @@ static SDValue combineSelect(SDNode *N, SelectionDAG &DAG, DL, DAG, Subtarget)) return V; + // If the sign bit is known then BLENDV can be folded away. + if (N->getOpcode() == X86ISD::BLENDV) { + KnownBits KnownCond = DAG.computeKnownBits(Cond); + if (KnownCond.isNegative()) + return LHS; + if (KnownCond.isNonNegative()) + return RHS; + } + if (N->getOpcode() == ISD::VSELECT || N->getOpcode() == X86ISD::BLENDV) { SmallVector<int, 64> CondMask; if (createShuffleMaskFromVSELECT(CondMask, Cond, |