Diffstat (limited to 'llvm/lib/Target/AMDGPU/SIISelLowering.cpp')
-rw-r--r--  llvm/lib/Target/AMDGPU/SIISelLowering.cpp | 122
1 file changed, 101 insertions(+), 21 deletions(-)
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 079cae0..209debb 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -167,8 +167,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
addRegisterClass(MVT::v4bf16, &AMDGPU::SReg_64RegClass);
addRegisterClass(MVT::v8i16, &AMDGPU::SGPR_128RegClass);
addRegisterClass(MVT::v8f16, &AMDGPU::SGPR_128RegClass);
+ addRegisterClass(MVT::v8bf16, &AMDGPU::SGPR_128RegClass);
addRegisterClass(MVT::v16i16, &AMDGPU::SGPR_256RegClass);
addRegisterClass(MVT::v16f16, &AMDGPU::SGPR_256RegClass);
+ addRegisterClass(MVT::v16bf16, &AMDGPU::SGPR_256RegClass);
addRegisterClass(MVT::v32i16, &AMDGPU::SGPR_512RegClass);
addRegisterClass(MVT::v32f16, &AMDGPU::SGPR_512RegClass);
}
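
For reference, the new bf16 vector types fit the scalar register classes chosen above purely by total bit width, matching their i16/f16 counterparts. A quick standalone arithmetic check (illustrative only, not part of the patch):

#include <cstdio>

int main() {
  const unsigned BF16Bits = 16; // bf16 elements are 16 bits wide
  std::printf("v4bf16  = %3u bits -> SReg_64\n", 4 * BF16Bits);   // 64
  std::printf("v8bf16  = %3u bits -> SGPR_128\n", 8 * BF16Bits);  // 128
  std::printf("v16bf16 = %3u bits -> SGPR_256\n", 16 * BF16Bits); // 256
  return 0;
}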
@@ -310,13 +312,14 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
// We only support LOAD/STORE and vector manipulation ops for vectors
// with > 4 elements.
for (MVT VT :
- {MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
- MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
- MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
- MVT::v4f16, MVT::v4bf16, MVT::v3i64, MVT::v3f64, MVT::v6i32,
- MVT::v6f32, MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64,
- MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16, MVT::v16i64,
- MVT::v16f64, MVT::v32i32, MVT::v32f32, MVT::v32i16, MVT::v32f16}) {
+ {MVT::v8i32, MVT::v8f32, MVT::v9i32, MVT::v9f32, MVT::v10i32,
+ MVT::v10f32, MVT::v11i32, MVT::v11f32, MVT::v12i32, MVT::v12f32,
+ MVT::v16i32, MVT::v16f32, MVT::v2i64, MVT::v2f64, MVT::v4i16,
+ MVT::v4f16, MVT::v4bf16, MVT::v3i64, MVT::v3f64, MVT::v6i32,
+ MVT::v6f32, MVT::v4i64, MVT::v4f64, MVT::v8i64, MVT::v8f64,
+ MVT::v8i16, MVT::v8f16, MVT::v8bf16, MVT::v16i16, MVT::v16f16,
+ MVT::v16bf16, MVT::v16i64, MVT::v16f64, MVT::v32i32, MVT::v32f32,
+ MVT::v32i16, MVT::v32f16, MVT::v32bf16}) {
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op) {
switch (Op) {
case ISD::LOAD:
@@ -683,6 +686,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
AddPromotedToType(ISD::LOAD, MVT::v8i16, MVT::v4i32);
setOperationAction(ISD::LOAD, MVT::v8f16, Promote);
AddPromotedToType(ISD::LOAD, MVT::v8f16, MVT::v4i32);
+ setOperationAction(ISD::LOAD, MVT::v8bf16, Promote);
+ AddPromotedToType(ISD::LOAD, MVT::v8bf16, MVT::v4i32);
setOperationAction(ISD::STORE, MVT::v4i16, Promote);
AddPromotedToType(ISD::STORE, MVT::v4i16, MVT::v2i32);
@@ -693,16 +698,22 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
AddPromotedToType(ISD::STORE, MVT::v8i16, MVT::v4i32);
setOperationAction(ISD::STORE, MVT::v8f16, Promote);
AddPromotedToType(ISD::STORE, MVT::v8f16, MVT::v4i32);
+ setOperationAction(ISD::STORE, MVT::v8bf16, Promote);
+ AddPromotedToType(ISD::STORE, MVT::v8bf16, MVT::v4i32);
setOperationAction(ISD::LOAD, MVT::v16i16, Promote);
AddPromotedToType(ISD::LOAD, MVT::v16i16, MVT::v8i32);
setOperationAction(ISD::LOAD, MVT::v16f16, Promote);
AddPromotedToType(ISD::LOAD, MVT::v16f16, MVT::v8i32);
+ setOperationAction(ISD::LOAD, MVT::v16bf16, Promote);
+ AddPromotedToType(ISD::LOAD, MVT::v16bf16, MVT::v8i32);
setOperationAction(ISD::STORE, MVT::v16i16, Promote);
AddPromotedToType(ISD::STORE, MVT::v16i16, MVT::v8i32);
setOperationAction(ISD::STORE, MVT::v16f16, Promote);
AddPromotedToType(ISD::STORE, MVT::v16f16, MVT::v8i32);
+ setOperationAction(ISD::STORE, MVT::v16bf16, Promote);
+ AddPromotedToType(ISD::STORE, MVT::v16bf16, MVT::v8i32);
setOperationAction(ISD::LOAD, MVT::v32i16, Promote);
AddPromotedToType(ISD::LOAD, MVT::v32i16, MVT::v16i32);
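
The Promote/AddPromotedToType pairs above legalize the new bf16 vector loads and stores by reusing the equally wide i32 vector memory operations, so the legalizer only needs to insert bitcasts around them. A standalone sketch of what that means at the bit level (hypothetical helper, not part of the patch):

#include <array>
#include <cstdint>
#include <cstring>

// Load eight bf16 elements (as raw 16-bit patterns) the way a promoted v8bf16
// load works: one 128-bit access done as v4i32, then reinterpreted in place.
std::array<uint16_t, 8> loadV8BF16(const void *Ptr) {
  std::array<uint32_t, 4> Words;                           // the v4i32 load
  std::memcpy(Words.data(), Ptr, sizeof(Words));
  std::array<uint16_t, 8> Elems;                           // the bitcast back
  std::memcpy(Elems.data(), Words.data(), sizeof(Elems));
  return Elems;
}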
@@ -725,7 +736,8 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
MVT::v8i32, Expand);
if (!Subtarget->hasVOP3PInsts())
- setOperationAction(ISD::BUILD_VECTOR, {MVT::v2i16, MVT::v2f16}, Custom);
+ setOperationAction(ISD::BUILD_VECTOR,
+ {MVT::v2i16, MVT::v2f16, MVT::v2bf16}, Custom);
setOperationAction(ISD::FNEG, MVT::v2f16, Legal);
// This isn't really legal, but this avoids the legalizer unrolling it (and
@@ -743,8 +755,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
{MVT::v4f16, MVT::v8f16, MVT::v16f16, MVT::v32f16},
Expand);
- for (MVT Vec16 : {MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
- MVT::v32i16, MVT::v32f16}) {
+ for (MVT Vec16 :
+ {MVT::v8i16, MVT::v8f16, MVT::v8bf16, MVT::v16i16, MVT::v16f16,
+ MVT::v16bf16, MVT::v32i16, MVT::v32f16, MVT::v32bf16}) {
setOperationAction(
{ISD::BUILD_VECTOR, ISD::EXTRACT_VECTOR_ELT, ISD::SCALAR_TO_VECTOR},
Vec16, Custom);
@@ -814,13 +827,17 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
}
setOperationAction(ISD::SELECT,
- {MVT::v4i16, MVT::v4f16, MVT::v2i8, MVT::v4i8, MVT::v8i8,
- MVT::v8i16, MVT::v8f16, MVT::v16i16, MVT::v16f16,
- MVT::v32i16, MVT::v32f16},
+ {MVT::v4i16, MVT::v4f16, MVT::v4bf16, MVT::v2i8, MVT::v4i8,
+ MVT::v8i8, MVT::v8i16, MVT::v8f16, MVT::v8bf16,
+ MVT::v16i16, MVT::v16f16, MVT::v16bf16, MVT::v32i16,
+ MVT::v32f16, MVT::v32bf16},
Custom);
setOperationAction({ISD::SMULO, ISD::UMULO}, MVT::i64, Custom);
+ if (Subtarget->hasScalarSMulU64())
+ setOperationAction(ISD::MUL, MVT::i64, Custom);
+
if (Subtarget->hasMad64_32())
setOperationAction({ISD::SMUL_LOHI, ISD::UMUL_LOHI}, MVT::i32, Custom);
@@ -5431,7 +5448,9 @@ SDValue SITargetLowering::splitTernaryVectorOp(SDValue Op,
assert(VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
VT == MVT::v8f16 || VT == MVT::v4f32 || VT == MVT::v16i16 ||
VT == MVT::v16f16 || VT == MVT::v8f32 || VT == MVT::v16f32 ||
- VT == MVT::v32f32 || VT == MVT::v32f16 || VT == MVT::v32i16);
+ VT == MVT::v32f32 || VT == MVT::v32f16 || VT == MVT::v32i16 ||
+ VT == MVT::v4bf16 || VT == MVT::v8bf16 || VT == MVT::v16bf16 ||
+ VT == MVT::v32bf16);
SDValue Lo0, Hi0;
SDValue Op0 = Op.getOperand(0);
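
The widened assert above only admits the new bf16 vector types into splitTernaryVectorOp, which lowers a wide three-operand op by applying it to the low and high halves of the vectors and concatenating the results. A minimal standalone sketch of that shape, with plain arrays standing in for DAG vectors (hypothetical names, not part of the patch):

#include <array>
#include <cstddef>

// Apply an fma-like ternary op to an 8-wide vector by splitting it into two
// 4-wide halves, mirroring how splitTernaryVectorOp handles wide types.
std::array<float, 8> splitFMA(const std::array<float, 8> &A,
                              const std::array<float, 8> &B,
                              const std::array<float, 8> &C) {
  std::array<float, 8> R;
  for (std::size_t Half = 0; Half < 2; ++Half)   // lo half, then hi half
    for (std::size_t I = 0; I < 4; ++I) {
      std::size_t Idx = Half * 4 + I;
      R[Idx] = A[Idx] * B[Idx] + C[Idx];         // the 4-wide "legal" op
    }
  return R;                                      // concatenation of the halves
}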
@@ -5550,7 +5569,6 @@ SDValue SITargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::SRL:
case ISD::ADD:
case ISD::SUB:
- case ISD::MUL:
case ISD::SMIN:
case ISD::SMAX:
case ISD::UMIN:
@@ -5564,6 +5582,8 @@ SDValue SITargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
case ISD::SADDSAT:
case ISD::SSUBSAT:
return splitBinaryVectorOp(Op, DAG);
+ case ISD::MUL:
+ return lowerMUL(Op, DAG);
case ISD::SMULO:
case ISD::UMULO:
return lowerXMULO(Op, DAG);
@@ -6219,6 +6239,66 @@ SDValue SITargetLowering::lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const {
return DAG.getNode(ISD::FLDEXP, DL, VT, Op.getOperand(0), TruncExp);
}
+// Custom lowering for vector multiplications and s_mul_u64.
+SDValue SITargetLowering::lowerMUL(SDValue Op, SelectionDAG &DAG) const {
+ EVT VT = Op.getValueType();
+
+ // Split vector operands.
+ if (VT.isVector())
+ return splitBinaryVectorOp(Op, DAG);
+
+ assert(VT == MVT::i64 && "The following code is special for s_mul_u64");
+
+ // There are four ways to lower s_mul_u64:
+ //
+ // 1. If all the operands are uniform, then we lower it as it is.
+ //
+ // 2. If the operands are divergent, then we have to split s_mul_u64 into
+ // 32-bit multiplications because there is no vector equivalent of s_mul_u64.
+ //
+ // 3. If the cost model decides that it is more efficient to use vector
+ // registers, then we have to split s_mul_u64 into 32-bit multiplications.
+ // This happens in splitScalarSMULU64() in SIInstrInfo.cpp.
+ //
+ // 4. If the cost model decides to use vector registers and both of the
+ // operands are zero-extended/sign-extended from 32 bits, then we split the
+ // s_mul_u64 into two 32-bit multiplications. The problem is that it is not
+ // possible to check whether the operands are zero-extended or sign-extended
+ // in SIInstrInfo.cpp. For this reason, here we replace s_mul_u64 with
+ // s_mul_u64_u32_pseudo if both operands are zero-extended, and with
+ // s_mul_i64_i32_pseudo if both operands are sign-extended.
+ // If the cost model decides that we have to use vector registers, then
+ // splitScalarSMulPseudo() (in SIInstrInfo.cpp) splits s_mul_u64_u32_pseudo/
+ // s_mul_i64_i32_pseudo into two vector multiplications. If the cost model
+ // decides that we should use scalar registers, then s_mul_u64_u32_pseudo/
+ // s_mul_i64_i32_pseudo is lowered back to s_mul_u64 in expandPostRAPseudo()
+ // in SIInstrInfo.cpp.
+
+ if (Op->isDivergent())
+ return SDValue();
+
+ SDValue Op0 = Op.getOperand(0);
+ SDValue Op1 = Op.getOperand(1);
+ // If both operands are zero-extended to 32 bits, then we replace s_mul_u64
+ // with s_mul_u64_u32_pseudo. If both operands are sign-extended to 32 bits,
+ // then we replace s_mul_u64 with s_mul_i64_i32_pseudo.
+ KnownBits Op0KnownBits = DAG.computeKnownBits(Op0);
+ unsigned Op0LeadingZeros = Op0KnownBits.countMinLeadingZeros();
+ KnownBits Op1KnownBits = DAG.computeKnownBits(Op1);
+ unsigned Op1LeadingZeros = Op1KnownBits.countMinLeadingZeros();
+ SDLoc SL(Op);
+ if (Op0LeadingZeros >= 32 && Op1LeadingZeros >= 32)
+ return SDValue(
+ DAG.getMachineNode(AMDGPU::S_MUL_U64_U32_PSEUDO, SL, VT, Op0, Op1), 0);
+ unsigned Op0SignBits = DAG.ComputeNumSignBits(Op0);
+ unsigned Op1SignBits = DAG.ComputeNumSignBits(Op1);
+ if (Op0SignBits >= 33 && Op1SignBits >= 33)
+ return SDValue(
+ DAG.getMachineNode(AMDGPU::S_MUL_I64_I32_PSEUDO, SL, VT, Op0, Op1), 0);
+ // If all the operands are uniform, then we lower s_mul_u64 as it is.
+ return Op;
+}
+
SDValue SITargetLowering::lowerXMULO(SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType();
SDLoc SL(Op);
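
The pseudo selection in lowerMUL above relies on two arithmetic facts: when both operands have at least 32 known leading zero bits, the 64-bit product equals the widened unsigned 32x32 product, and when both have at least 33 known sign bits, it equals the widened signed 32x32 product. A standalone check of those identities (plain C++, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // Both operands zero-extended from 32 bits (>= 32 known leading zeros).
  uint64_t A = 0x00000000FFFFFFFFULL, B = 0x0000000012345678ULL;
  assert(A * B == (uint64_t)(uint32_t)A * (uint64_t)(uint32_t)B);

  // Both operands sign-extended from 32 bits (>= 33 known sign bits).
  int64_t C = -5, D = 123456789;
  assert(C * D == (int64_t)(int32_t)C * (int64_t)(int32_t)D);
  return 0;
}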
@@ -6854,8 +6934,8 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
SDLoc SL(Op);
EVT VT = Op.getValueType();
- if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v4bf16 ||
- VT == MVT::v8i16 || VT == MVT::v8f16) {
+ if (VT == MVT::v4i16 || VT == MVT::v4f16 || VT == MVT::v8i16 ||
+ VT == MVT::v8f16 || VT == MVT::v4bf16 || VT == MVT::v8bf16) {
EVT HalfVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
VT.getVectorNumElements() / 2);
MVT HalfIntVT = MVT::getIntegerVT(HalfVT.getSizeInBits());
@@ -6878,7 +6958,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
}
- if (VT == MVT::v16i16 || VT == MVT::v16f16) {
+ if (VT == MVT::v16i16 || VT == MVT::v16f16 || VT == MVT::v16bf16) {
EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
VT.getVectorNumElements() / 4);
MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
@@ -6899,7 +6979,7 @@ SDValue SITargetLowering::lowerBUILD_VECTOR(SDValue Op,
return DAG.getNode(ISD::BITCAST, SL, VT, Blend);
}
- if (VT == MVT::v32i16 || VT == MVT::v32f16) {
+ if (VT == MVT::v32i16 || VT == MVT::v32f16 || VT == MVT::v32bf16) {
EVT QuarterVT = MVT::getVectorVT(VT.getVectorElementType().getSimpleVT(),
VT.getVectorNumElements() / 8);
MVT QuarterIntVT = MVT::getIntegerVT(QuarterVT.getSizeInBits());
@@ -14182,11 +14262,11 @@ SDValue SITargetLowering::PerformDAGCombine(SDNode *N,
EVT VT = N->getValueType(0);
// v2i16 (scalar_to_vector i16:x) -> v2i16 (bitcast (any_extend i16:x))
- if (VT == MVT::v2i16 || VT == MVT::v2f16) {
+ if (VT == MVT::v2i16 || VT == MVT::v2f16 || VT == MVT::v2bf16) {
SDLoc SL(N);
SDValue Src = N->getOperand(0);
EVT EltVT = Src.getValueType();
- if (EltVT == MVT::f16)
+ if (EltVT != MVT::i16)
Src = DAG.getNode(ISD::BITCAST, SL, MVT::i16, Src);
SDValue Ext = DAG.getNode(ISD::ANY_EXTEND, SL, MVT::i32, Src);
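
The combine in the final hunk rewrites (scalar_to_vector x) for v2i16/v2f16/v2bf16 as a bitcast of an any_extend. That is sound because lane 0 of a 2x16 vector sits in the low 16 bits of the 32-bit word on this target, and the other lane is undefined either way. A small standalone bit-level check, assuming a little-endian host for the memcpy stand-ins (not part of the patch):

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  uint16_t X = 0xBEEF;                     // the scalar element
  uint32_t Word = X;                       // stands in for the any_extend to i32
  uint16_t Lanes[2];
  std::memcpy(Lanes, &Word, sizeof(Word)); // stands in for the bitcast to v2i16
  assert(Lanes[0] == X);                   // lane 0 holds the scalar
  return 0;
}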