diff options
Diffstat (limited to 'llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp')
-rw-r--r-- | llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp | 282 |
1 files changed, 161 insertions, 121 deletions
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp index 2378664..d96136c 100644 --- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp +++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp @@ -2385,23 +2385,6 @@ SDValue LoongArchTargetLowering::lowerBF16_TO_FP(SDValue Op, return Res; } -static bool isConstantOrUndef(const SDValue Op) { - if (Op->isUndef()) - return true; - if (isa<ConstantSDNode>(Op)) - return true; - if (isa<ConstantFPSDNode>(Op)) - return true; - return false; -} - -static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) { - for (unsigned i = 0; i < Op->getNumOperands(); ++i) - if (isConstantOrUndef(Op->getOperand(i))) - return true; - return false; -} - // Lower BUILD_VECTOR as broadcast load (if possible). // For example: // %a = load i8, ptr %ptr @@ -2451,10 +2434,14 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const { BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op); EVT ResTy = Op->getValueType(0); + unsigned NumElts = ResTy.getVectorNumElements(); SDLoc DL(Op); APInt SplatValue, SplatUndef; unsigned SplatBitSize; bool HasAnyUndefs; + bool IsConstant = false; + bool UseSameConstant = true; + SDValue ConstantValue; bool Is128Vec = ResTy.is128BitVector(); bool Is256Vec = ResTy.is256BitVector(); @@ -2505,19 +2492,45 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op, if (DAG.isSplatValue(Op, /*AllowUndefs=*/false)) return Op; - if (!isConstantOrUndefBUILD_VECTOR(Node)) { + for (unsigned i = 0; i < NumElts; ++i) { + SDValue Opi = Node->getOperand(i); + if (isIntOrFPConstant(Opi)) { + IsConstant = true; + if (!ConstantValue.getNode()) + ConstantValue = Opi; + else if (ConstantValue != Opi) + UseSameConstant = false; + } + } + + // If the type of BUILD_VECTOR is v2f64, custom legalizing it has no benefits. + if (IsConstant && UseSameConstant && ResTy != MVT::v2f64) { + SDValue Result = DAG.getSplatBuildVector(ResTy, DL, ConstantValue); + for (unsigned i = 0; i < NumElts; ++i) { + SDValue Opi = Node->getOperand(i); + if (!isIntOrFPConstant(Opi)) + Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ResTy, Result, Opi, + DAG.getConstant(i, DL, Subtarget.getGRLenVT())); + } + return Result; + } + + if (!IsConstant) { // Use INSERT_VECTOR_ELT operations rather than expand to stores. // The resulting code is the same length as the expansion, but it doesn't // use memory operations. - EVT ResTy = Node->getValueType(0); - assert(ResTy.isVector()); - unsigned NumElts = ResTy.getVectorNumElements(); + SDValue Op0 = Node->getOperand(0); SDValue Vector = DAG.getUNDEF(ResTy); - for (unsigned i = 0; i < NumElts; ++i) { - Vector = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ResTy, Vector, - Node->getOperand(i), + + if (!Op0.isUndef()) + Vector = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ResTy, Op0); + for (unsigned i = 1; i < NumElts; ++i) { + SDValue Opi = Node->getOperand(i); + if (Opi.isUndef()) + continue; + Vector = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ResTy, Vector, Opi, DAG.getConstant(i, DL, Subtarget.getGRLenVT())); } return Vector; @@ -4560,6 +4573,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT, llvm_unreachable("Unexpected node type for vXi1 sign extension"); } +static SDValue +performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG, + TargetLowering::DAGCombinerInfo &DCI, + const LoongArchSubtarget &Subtarget) { + SDLoc DL(N); + EVT VT = N->getValueType(0); + SDValue Src = N->getOperand(0); + EVT SrcVT = Src.getValueType(); + + if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse()) + return SDValue(); + + bool UseLASX; + unsigned Opc = ISD::DELETED_NODE; + EVT CmpVT = Src.getOperand(0).getValueType(); + EVT EltVT = CmpVT.getVectorElementType(); + + if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128) + UseLASX = false; + else if (Subtarget.has32S() && Subtarget.hasExtLASX() && + CmpVT.getSizeInBits() == 256) + UseLASX = true; + else + return SDValue(); + + SDValue SrcN1 = Src.getOperand(1); + switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) { + default: + break; + case ISD::SETEQ: + // x == 0 => not (vmsknez.b x) + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ; + break; + case ISD::SETGT: + // x > -1 => vmskgez.b x + if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; + break; + case ISD::SETGE: + // x >= 0 => vmskgez.b x + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; + break; + case ISD::SETLT: + // x < 0 => vmskltz.{b,h,w,d} x + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && + (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || + EltVT == MVT::i64)) + Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; + break; + case ISD::SETLE: + // x <= -1 => vmskltz.{b,h,w,d} x + if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && + (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || + EltVT == MVT::i64)) + Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; + break; + case ISD::SETNE: + // x != 0 => vmsknez.b x + if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) + Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ; + break; + } + + if (Opc == ISD::DELETED_NODE) + return SDValue(); + + SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0)); + EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements()); + V = DAG.getZExtOrTrunc(V, DL, T); + return DAG.getBitcast(VT, V); +} + static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG, TargetLowering::DAGCombinerInfo &DCI, const LoongArchSubtarget &Subtarget) { @@ -4574,110 +4661,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG, if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1) return SDValue(); - unsigned Opc = ISD::DELETED_NODE; // Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible + SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget); + if (Res) + return Res; + + // Generate vXi1 using [X]VMSKLTZ + MVT SExtVT; + unsigned Opc; + bool UseLASX = false; + bool PropagateSExt = false; + if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) { - bool UseLASX; EVT CmpVT = Src.getOperand(0).getValueType(); - EVT EltVT = CmpVT.getVectorElementType(); - - if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128) - UseLASX = false; - else if (Subtarget.has32S() && Subtarget.hasExtLASX() && - CmpVT.getSizeInBits() <= 256) - UseLASX = true; - else + if (CmpVT.getSizeInBits() > 256) return SDValue(); - - SDValue SrcN1 = Src.getOperand(1); - switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) { - default: - break; - case ISD::SETEQ: - // x == 0 => not (vmsknez.b x) - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ; - break; - case ISD::SETGT: - // x > -1 => vmskgez.b x - if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; - break; - case ISD::SETGE: - // x >= 0 => vmskgez.b x - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ; - break; - case ISD::SETLT: - // x < 0 => vmskltz.{b,h,w,d} x - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && - (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || - EltVT == MVT::i64)) - Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; - break; - case ISD::SETLE: - // x <= -1 => vmskltz.{b,h,w,d} x - if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && - (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 || - EltVT == MVT::i64)) - Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; - break; - case ISD::SETNE: - // x != 0 => vmsknez.b x - if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8) - Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ; - break; - } } - // Generate vXi1 using [X]VMSKLTZ - if (Opc == ISD::DELETED_NODE) { - MVT SExtVT; - bool UseLASX = false; - bool PropagateSExt = false; - switch (SrcVT.getSimpleVT().SimpleTy) { - default: - return SDValue(); - case MVT::v2i1: - SExtVT = MVT::v2i64; - break; - case MVT::v4i1: - SExtVT = MVT::v4i32; - if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { - SExtVT = MVT::v4i64; - UseLASX = true; - PropagateSExt = true; - } - break; - case MVT::v8i1: - SExtVT = MVT::v8i16; - if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { - SExtVT = MVT::v8i32; - UseLASX = true; - PropagateSExt = true; - } - break; - case MVT::v16i1: - SExtVT = MVT::v16i8; - if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { - SExtVT = MVT::v16i16; - UseLASX = true; - PropagateSExt = true; - } - break; - case MVT::v32i1: - SExtVT = MVT::v32i8; + switch (SrcVT.getSimpleVT().SimpleTy) { + default: + return SDValue(); + case MVT::v2i1: + SExtVT = MVT::v2i64; + break; + case MVT::v4i1: + SExtVT = MVT::v4i32; + if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { + SExtVT = MVT::v4i64; UseLASX = true; - break; - }; - if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX()) - return SDValue(); - Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL) - : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src); - Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; - } else { - Src = Src.getOperand(0); - } + PropagateSExt = true; + } + break; + case MVT::v8i1: + SExtVT = MVT::v8i16; + if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { + SExtVT = MVT::v8i32; + UseLASX = true; + PropagateSExt = true; + } + break; + case MVT::v16i1: + SExtVT = MVT::v16i8; + if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) { + SExtVT = MVT::v16i16; + UseLASX = true; + PropagateSExt = true; + } + break; + case MVT::v32i1: + SExtVT = MVT::v32i8; + UseLASX = true; + break; + }; + if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX())) + return SDValue(); + Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL) + : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src); + Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ; SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src); EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements()); |