aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp')
-rw-r--r--llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp282
1 files changed, 161 insertions, 121 deletions
diff --git a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
index 2378664..d96136c 100644
--- a/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
+++ b/llvm/lib/Target/LoongArch/LoongArchISelLowering.cpp
@@ -2385,23 +2385,6 @@ SDValue LoongArchTargetLowering::lowerBF16_TO_FP(SDValue Op,
return Res;
}
-static bool isConstantOrUndef(const SDValue Op) {
- if (Op->isUndef())
- return true;
- if (isa<ConstantSDNode>(Op))
- return true;
- if (isa<ConstantFPSDNode>(Op))
- return true;
- return false;
-}
-
-static bool isConstantOrUndefBUILD_VECTOR(const BuildVectorSDNode *Op) {
- for (unsigned i = 0; i < Op->getNumOperands(); ++i)
- if (isConstantOrUndef(Op->getOperand(i)))
- return true;
- return false;
-}
-
// Lower BUILD_VECTOR as broadcast load (if possible).
// For example:
// %a = load i8, ptr %ptr
@@ -2451,10 +2434,14 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
BuildVectorSDNode *Node = cast<BuildVectorSDNode>(Op);
EVT ResTy = Op->getValueType(0);
+ unsigned NumElts = ResTy.getVectorNumElements();
SDLoc DL(Op);
APInt SplatValue, SplatUndef;
unsigned SplatBitSize;
bool HasAnyUndefs;
+ bool IsConstant = false;
+ bool UseSameConstant = true;
+ SDValue ConstantValue;
bool Is128Vec = ResTy.is128BitVector();
bool Is256Vec = ResTy.is256BitVector();
@@ -2505,19 +2492,45 @@ SDValue LoongArchTargetLowering::lowerBUILD_VECTOR(SDValue Op,
if (DAG.isSplatValue(Op, /*AllowUndefs=*/false))
return Op;
- if (!isConstantOrUndefBUILD_VECTOR(Node)) {
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue Opi = Node->getOperand(i);
+ if (isIntOrFPConstant(Opi)) {
+ IsConstant = true;
+ if (!ConstantValue.getNode())
+ ConstantValue = Opi;
+ else if (ConstantValue != Opi)
+ UseSameConstant = false;
+ }
+ }
+
+ // If the type of BUILD_VECTOR is v2f64, custom legalizing it has no benefits.
+ if (IsConstant && UseSameConstant && ResTy != MVT::v2f64) {
+ SDValue Result = DAG.getSplatBuildVector(ResTy, DL, ConstantValue);
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue Opi = Node->getOperand(i);
+ if (!isIntOrFPConstant(Opi))
+ Result = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ResTy, Result, Opi,
+ DAG.getConstant(i, DL, Subtarget.getGRLenVT()));
+ }
+ return Result;
+ }
+
+ if (!IsConstant) {
// Use INSERT_VECTOR_ELT operations rather than expand to stores.
// The resulting code is the same length as the expansion, but it doesn't
// use memory operations.
- EVT ResTy = Node->getValueType(0);
-
assert(ResTy.isVector());
- unsigned NumElts = ResTy.getVectorNumElements();
+ SDValue Op0 = Node->getOperand(0);
SDValue Vector = DAG.getUNDEF(ResTy);
- for (unsigned i = 0; i < NumElts; ++i) {
- Vector = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ResTy, Vector,
- Node->getOperand(i),
+
+ if (!Op0.isUndef())
+ Vector = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ResTy, Op0);
+ for (unsigned i = 1; i < NumElts; ++i) {
+ SDValue Opi = Node->getOperand(i);
+ if (Opi.isUndef())
+ continue;
+ Vector = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ResTy, Vector, Opi,
DAG.getConstant(i, DL, Subtarget.getGRLenVT()));
}
return Vector;
@@ -4560,6 +4573,80 @@ static SDValue signExtendBitcastSrcVector(SelectionDAG &DAG, EVT SExtVT,
llvm_unreachable("Unexpected node type for vXi1 sign extension");
}
+static SDValue
+performSETCC_BITCASTCombine(SDNode *N, SelectionDAG &DAG,
+ TargetLowering::DAGCombinerInfo &DCI,
+ const LoongArchSubtarget &Subtarget) {
+ SDLoc DL(N);
+ EVT VT = N->getValueType(0);
+ SDValue Src = N->getOperand(0);
+ EVT SrcVT = Src.getValueType();
+
+ if (Src.getOpcode() != ISD::SETCC || !Src.hasOneUse())
+ return SDValue();
+
+ bool UseLASX;
+ unsigned Opc = ISD::DELETED_NODE;
+ EVT CmpVT = Src.getOperand(0).getValueType();
+ EVT EltVT = CmpVT.getVectorElementType();
+
+ if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() == 128)
+ UseLASX = false;
+ else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
+ CmpVT.getSizeInBits() == 256)
+ UseLASX = true;
+ else
+ return SDValue();
+
+ SDValue SrcN1 = Src.getOperand(1);
+ switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
+ default:
+ break;
+ case ISD::SETEQ:
+ // x == 0 => not (vmsknez.b x)
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
+ break;
+ case ISD::SETGT:
+ // x > -1 => vmskgez.b x
+ if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+ break;
+ case ISD::SETGE:
+ // x >= 0 => vmskgez.b x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
+ break;
+ case ISD::SETLT:
+ // x < 0 => vmskltz.{b,h,w,d} x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+ EltVT == MVT::i64))
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+ break;
+ case ISD::SETLE:
+ // x <= -1 => vmskltz.{b,h,w,d} x
+ if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
+ (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
+ EltVT == MVT::i64))
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
+ break;
+ case ISD::SETNE:
+ // x != 0 => vmsknez.b x
+ if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
+ Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
+ break;
+ }
+
+ if (Opc == ISD::DELETED_NODE)
+ return SDValue();
+
+ SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src.getOperand(0));
+ EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());
+ V = DAG.getZExtOrTrunc(V, DL, T);
+ return DAG.getBitcast(VT, V);
+}
+
static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const LoongArchSubtarget &Subtarget) {
@@ -4574,110 +4661,63 @@ static SDValue performBITCASTCombine(SDNode *N, SelectionDAG &DAG,
if (!SrcVT.isSimple() || SrcVT.getScalarType() != MVT::i1)
return SDValue();
- unsigned Opc = ISD::DELETED_NODE;
// Combine SETCC and BITCAST into [X]VMSK{LT,GE,NE} when possible
+ SDValue Res = performSETCC_BITCASTCombine(N, DAG, DCI, Subtarget);
+ if (Res)
+ return Res;
+
+ // Generate vXi1 using [X]VMSKLTZ
+ MVT SExtVT;
+ unsigned Opc;
+ bool UseLASX = false;
+ bool PropagateSExt = false;
+
if (Src.getOpcode() == ISD::SETCC && Src.hasOneUse()) {
- bool UseLASX;
EVT CmpVT = Src.getOperand(0).getValueType();
- EVT EltVT = CmpVT.getVectorElementType();
-
- if (Subtarget.hasExtLSX() && CmpVT.getSizeInBits() <= 128)
- UseLASX = false;
- else if (Subtarget.has32S() && Subtarget.hasExtLASX() &&
- CmpVT.getSizeInBits() <= 256)
- UseLASX = true;
- else
+ if (CmpVT.getSizeInBits() > 256)
return SDValue();
-
- SDValue SrcN1 = Src.getOperand(1);
- switch (cast<CondCodeSDNode>(Src.getOperand(2))->get()) {
- default:
- break;
- case ISD::SETEQ:
- // x == 0 => not (vmsknez.b x)
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKEQZ : LoongArchISD::VMSKEQZ;
- break;
- case ISD::SETGT:
- // x > -1 => vmskgez.b x
- if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
- break;
- case ISD::SETGE:
- // x >= 0 => vmskgez.b x
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKGEZ : LoongArchISD::VMSKGEZ;
- break;
- case ISD::SETLT:
- // x < 0 => vmskltz.{b,h,w,d} x
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) &&
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
- EltVT == MVT::i64))
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
- break;
- case ISD::SETLE:
- // x <= -1 => vmskltz.{b,h,w,d} x
- if (ISD::isBuildVectorAllOnes(SrcN1.getNode()) &&
- (EltVT == MVT::i8 || EltVT == MVT::i16 || EltVT == MVT::i32 ||
- EltVT == MVT::i64))
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
- break;
- case ISD::SETNE:
- // x != 0 => vmsknez.b x
- if (ISD::isBuildVectorAllZeros(SrcN1.getNode()) && EltVT == MVT::i8)
- Opc = UseLASX ? LoongArchISD::XVMSKNEZ : LoongArchISD::VMSKNEZ;
- break;
- }
}
- // Generate vXi1 using [X]VMSKLTZ
- if (Opc == ISD::DELETED_NODE) {
- MVT SExtVT;
- bool UseLASX = false;
- bool PropagateSExt = false;
- switch (SrcVT.getSimpleVT().SimpleTy) {
- default:
- return SDValue();
- case MVT::v2i1:
- SExtVT = MVT::v2i64;
- break;
- case MVT::v4i1:
- SExtVT = MVT::v4i32;
- if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
- SExtVT = MVT::v4i64;
- UseLASX = true;
- PropagateSExt = true;
- }
- break;
- case MVT::v8i1:
- SExtVT = MVT::v8i16;
- if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
- SExtVT = MVT::v8i32;
- UseLASX = true;
- PropagateSExt = true;
- }
- break;
- case MVT::v16i1:
- SExtVT = MVT::v16i8;
- if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
- SExtVT = MVT::v16i16;
- UseLASX = true;
- PropagateSExt = true;
- }
- break;
- case MVT::v32i1:
- SExtVT = MVT::v32i8;
+ switch (SrcVT.getSimpleVT().SimpleTy) {
+ default:
+ return SDValue();
+ case MVT::v2i1:
+ SExtVT = MVT::v2i64;
+ break;
+ case MVT::v4i1:
+ SExtVT = MVT::v4i32;
+ if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+ SExtVT = MVT::v4i64;
UseLASX = true;
- break;
- };
- if (UseLASX && !Subtarget.has32S() && !Subtarget.hasExtLASX())
- return SDValue();
- Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
- : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
- Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
- } else {
- Src = Src.getOperand(0);
- }
+ PropagateSExt = true;
+ }
+ break;
+ case MVT::v8i1:
+ SExtVT = MVT::v8i16;
+ if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+ SExtVT = MVT::v8i32;
+ UseLASX = true;
+ PropagateSExt = true;
+ }
+ break;
+ case MVT::v16i1:
+ SExtVT = MVT::v16i8;
+ if (Subtarget.hasExtLASX() && checkBitcastSrcVectorSize(Src, 256, 0)) {
+ SExtVT = MVT::v16i16;
+ UseLASX = true;
+ PropagateSExt = true;
+ }
+ break;
+ case MVT::v32i1:
+ SExtVT = MVT::v32i8;
+ UseLASX = true;
+ break;
+ };
+ if (UseLASX && !(Subtarget.has32S() && Subtarget.hasExtLASX()))
+ return SDValue();
+ Src = PropagateSExt ? signExtendBitcastSrcVector(DAG, SExtVT, Src, DL)
+ : DAG.getNode(ISD::SIGN_EXTEND, DL, SExtVT, Src);
+ Opc = UseLASX ? LoongArchISD::XVMSKLTZ : LoongArchISD::VMSKLTZ;
SDValue V = DAG.getNode(Opc, DL, MVT::i64, Src);
EVT T = EVT::getIntegerVT(*DAG.getContext(), SrcVT.getVectorNumElements());