aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target')
-rw-r--r--llvm/lib/Target/AArch64/AArch64ISelLowering.cpp200
-rw-r--r--llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp64
-rw-r--r--llvm/lib/Target/AArch64/AArch64Processors.td11
-rw-r--r--llvm/lib/Target/AVR/AVRISelLowering.cpp7
-rw-r--r--llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp2
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp12
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXIntrinsics.td62
-rw-r--r--llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp8
-rw-r--r--llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td163
9 files changed, 323 insertions, 206 deletions
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index a40de86b..3c06c6a 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14742,6 +14742,106 @@ static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
return ResultSLI;
}
+static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
+ const AArch64TargetLowering &TLI) { // Tries to fold N, whose two operands are both ANDs, into an AArch64ISD::BSP node; returns an empty SDValue when no pattern matches.
+ EVT VT = N->getValueType(0);
+ SelectionDAG &DAG = DCI.DAG;
+ SDLoc DL(N);
+ const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
+
+ if (!VT.isVector()) // Only vector results are handled.
+ return SDValue();
+
+ if (VT.isScalableVector() && !Subtarget.hasSVE2()) // Scalable vectors are only handled when SVE2 is present.
+ return SDValue();
+
+ if (VT.isFixedLengthVector() &&
+ (!Subtarget.isNeonAvailable() || TLI.useSVEForFixedLengthVectorVT(VT))) // Fixed-length vectors: bail out without NEON, or when useSVEForFixedLengthVectorVT(VT) is true.
+ return SDValue();
+
+ SDValue N0 = N->getOperand(0);
+ if (N0.getOpcode() != ISD::AND) // First operand of N must be an AND.
+ return SDValue();
+
+ SDValue N1 = N->getOperand(1);
+ if (N1.getOpcode() != ISD::AND) // Second operand of N must be an AND as well.
+ return SDValue();
+
+ // InstCombine does (not (neg a)) => (add a -1).
+ // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c)
+ // Loop over all combinations of AND operands.
+ for (int i = 1; i >= 0; --i) {
+ for (int j = 1; j >= 0; --j) {
+ SDValue O0 = N0->getOperand(i);
+ SDValue O1 = N1->getOperand(j);
+ SDValue Sub, Add, SubSibling, AddSibling;
+
+ // Find a SUB and an ADD operand, one from each AND.
+ if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) {
+ Sub = O0;
+ Add = O1;
+ SubSibling = N0->getOperand(1 - i); // The other operand of the AND that holds the SUB.
+ AddSibling = N1->getOperand(1 - j); // The other operand of the AND that holds the ADD.
+ } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) {
+ Add = O0;
+ Sub = O1;
+ AddSibling = N0->getOperand(1 - i);
+ SubSibling = N1->getOperand(1 - j);
+ } else
+ continue;
+
+ if (!ISD::isConstantSplatVectorAllZeros(Sub.getOperand(0).getNode())) // The SUB must be a negation, i.e. (sub 0, a).
+ continue;
+
+ // The all-ones constant is always the right-hand operand of the ADD.
+ if (!ISD::isConstantSplatVectorAllOnes(Add.getOperand(1).getNode()))
+ continue;
+
+ if (Sub.getOperand(1) != Add.getOperand(0)) // The negated value and the ADD's left operand must be the same value `a`.
+ continue;
+
+ return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling); // Operand order follows the (bsl (neg a) b c) form above.
+ }
+ }
+
+ // (or (and a b) (and (not a) c)) => (bsl a b c)
+ // We only have to look for constant vectors here since the general, variable
+ // case can be handled in TableGen.
+ unsigned Bits = VT.getScalarSizeInBits();
+ uint64_t BitMask = Bits == 64 ? -1ULL : ((1ULL << Bits) - 1); // Mask of the low `Bits` bits; 64 is special-cased because 1ULL << 64 would be undefined.
+ for (int i = 1; i >= 0; --i)
+ for (int j = 1; j >= 0; --j) {
+ APInt Val1, Val2;
+
+ if (ISD::isConstantSplatVector(N0->getOperand(i).getNode(), Val1) &&
+ ISD::isConstantSplatVector(N1->getOperand(j).getNode(), Val2) &&
+ (BitMask & ~Val1.getZExtValue()) == Val2.getZExtValue()) { // Splat case: the two splat values must be bitwise complements within the element width.
+ return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
+ N0->getOperand(1 - i), N1->getOperand(1 - j));
+ }
+ BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(i));
+ BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(j));
+ if (!BVN0 || !BVN1) // Non-splat case only applies when both candidates are BUILD_VECTOR nodes.
+ continue;
+
+ bool FoundMatch = true;
+ for (unsigned k = 0; k < VT.getVectorNumElements(); ++k) { // Compare the two vectors element by element.
+ ConstantSDNode *CN0 = dyn_cast<ConstantSDNode>(BVN0->getOperand(k));
+ ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(BVN1->getOperand(k));
+ if (!CN0 || !CN1 ||
+ CN0->getZExtValue() != (BitMask & ~CN1->getZExtValue())) { // Both elements must be constants and complements of each other under BitMask.
+ FoundMatch = false;
+ break;
+ }
+ }
+ if (FoundMatch)
+ return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
+ N0->getOperand(1 - i), N1->getOperand(1 - j));
+ }
+
+ return SDValue(); // No pattern matched.
+}
+
SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
SelectionDAG &DAG) const {
if (useSVEForFixedLengthVectorVT(Op.getValueType(),
@@ -19419,106 +19519,6 @@ static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
return FixConv;
}
-static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
- const AArch64TargetLowering &TLI) {
- EVT VT = N->getValueType(0);
- SelectionDAG &DAG = DCI.DAG;
- SDLoc DL(N);
- const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
-
- if (!VT.isVector())
- return SDValue();
-
- if (VT.isScalableVector() && !Subtarget.hasSVE2())
- return SDValue();
-
- if (VT.isFixedLengthVector() &&
- (!Subtarget.isNeonAvailable() || TLI.useSVEForFixedLengthVectorVT(VT)))
- return SDValue();
-
- SDValue N0 = N->getOperand(0);
- if (N0.getOpcode() != ISD::AND)
- return SDValue();
-
- SDValue N1 = N->getOperand(1);
- if (N1.getOpcode() != ISD::AND)
- return SDValue();
-
- // InstCombine does (not (neg a)) => (add a -1).
- // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c)
- // Loop over all combinations of AND operands.
- for (int i = 1; i >= 0; --i) {
- for (int j = 1; j >= 0; --j) {
- SDValue O0 = N0->getOperand(i);
- SDValue O1 = N1->getOperand(j);
- SDValue Sub, Add, SubSibling, AddSibling;
-
- // Find a SUB and an ADD operand, one from each AND.
- if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) {
- Sub = O0;
- Add = O1;
- SubSibling = N0->getOperand(1 - i);
- AddSibling = N1->getOperand(1 - j);
- } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) {
- Add = O0;
- Sub = O1;
- AddSibling = N0->getOperand(1 - i);
- SubSibling = N1->getOperand(1 - j);
- } else
- continue;
-
- if (!ISD::isConstantSplatVectorAllZeros(Sub.getOperand(0).getNode()))
- continue;
-
- // Constant ones is always righthand operand of the Add.
- if (!ISD::isConstantSplatVectorAllOnes(Add.getOperand(1).getNode()))
- continue;
-
- if (Sub.getOperand(1) != Add.getOperand(0))
- continue;
-
- return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling);
- }
- }
-
- // (or (and a b) (and (not a) c)) => (bsl a b c)
- // We only have to look for constant vectors here since the general, variable
- // case can be handled in TableGen.
- unsigned Bits = VT.getScalarSizeInBits();
- uint64_t BitMask = Bits == 64 ? -1ULL : ((1ULL << Bits) - 1);
- for (int i = 1; i >= 0; --i)
- for (int j = 1; j >= 0; --j) {
- APInt Val1, Val2;
-
- if (ISD::isConstantSplatVector(N0->getOperand(i).getNode(), Val1) &&
- ISD::isConstantSplatVector(N1->getOperand(j).getNode(), Val2) &&
- (BitMask & ~Val1.getZExtValue()) == Val2.getZExtValue()) {
- return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
- N0->getOperand(1 - i), N1->getOperand(1 - j));
- }
- BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(i));
- BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(j));
- if (!BVN0 || !BVN1)
- continue;
-
- bool FoundMatch = true;
- for (unsigned k = 0; k < VT.getVectorNumElements(); ++k) {
- ConstantSDNode *CN0 = dyn_cast<ConstantSDNode>(BVN0->getOperand(k));
- ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(BVN1->getOperand(k));
- if (!CN0 || !CN1 ||
- CN0->getZExtValue() != (BitMask & ~CN1->getZExtValue())) {
- FoundMatch = false;
- break;
- }
- }
- if (FoundMatch)
- return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
- N0->getOperand(1 - i), N1->getOperand(1 - j));
- }
-
- return SDValue();
-}
-
// Given a tree of and/or(csel(0, 1, cc0), csel(0, 1, cc1)), we may be able to
// convert to csel(ccmp(.., cc0)), depending on cc1:
diff --git a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
index b97d622..fd4ef2a 100644
--- a/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
+++ b/llvm/lib/Target/AArch64/AArch64MIPeepholeOpt.cpp
@@ -8,8 +8,8 @@
//
// This pass performs below peephole optimizations on MIR level.
//
-// 1. MOVi32imm + ANDS?Wrr ==> ANDWri + ANDS?Wri
-// MOVi64imm + ANDS?Xrr ==> ANDXri + ANDS?Xri
+// 1. MOVi32imm + (ANDS?|EOR|ORR)Wrr ==> (AND|EOR|ORR)Wri + (ANDS?|EOR|ORR)Wri
+// MOVi64imm + (ANDS?|EOR|ORR)Xrr ==> (AND|EOR|ORR)Xri + (ANDS?|EOR|ORR)Xri
//
// 2. MOVi32imm + ADDWrr ==> ADDWRi + ADDWRi
// MOVi64imm + ADDXrr ==> ADDXri + ADDXri
@@ -128,6 +128,7 @@ struct AArch64MIPeepholeOpt : public MachineFunctionPass {
// Strategy used to split logical immediate bitmasks.
enum class SplitStrategy {
Intersect,
+ Disjoint,
};
template <typename T>
bool trySplitLogicalImm(unsigned Opc, MachineInstr &MI,
@@ -163,6 +164,7 @@ INITIALIZE_PASS(AArch64MIPeepholeOpt, "aarch64-mi-peephole-opt",
template <typename T>
static bool splitBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc, T &Imm2Enc) {
T UImm = static_cast<T>(Imm);
+ assert(UImm && (UImm != ~static_cast<T>(0)) && "Invalid immediate!");
// The bitmask immediate consists of consecutive ones. Let's say there is
// constant 0b00000000001000000000010000000000 which does not consist of
@@ -191,18 +193,47 @@ static bool splitBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc, T &Imm2Enc) {
}
template <typename T>
+static bool splitDisjointBitmaskImm(T Imm, unsigned RegSize, T &Imm1Enc,
+ T &Imm2Enc) { // Returns true and writes both encodings on success; returns false with the outputs untouched otherwise.
+ assert(Imm && (Imm != ~static_cast<T>(0)) && "Invalid immediate!"); // Zero and all-ones inputs are not accepted.
+
+ // Try to split a bitmask of the form 0b00000000011000000000011110000000 into
+ // two disjoint masks such as 0b00000000011000000000000000000000 and
+ // 0b00000000000000000000011110000000 where the inclusive/exclusive OR of the
+ // new masks matches the original mask.
+ unsigned LowestBitSet = llvm::countr_zero(Imm); // Index of the lowest set bit.
+ unsigned LowestGapBitUnset =
+ LowestBitSet + llvm::countr_one(Imm >> LowestBitSet); // Index of the first clear bit above the lowest run of ones.
+
+ // Create a mask for the least significant group of consecutive ones.
+ assert(LowestGapBitUnset < sizeof(T) * CHAR_BIT && "Undefined behaviour!"); // Guards the shift below: shifting by the full width of T is undefined.
+ T NewImm1 = (static_cast<T>(1) << LowestGapBitUnset) -
+ (static_cast<T>(1) << LowestBitSet); // Ones in bit positions [LowestBitSet, LowestGapBitUnset).
+ // Create a disjoint mask for the remaining ones.
+ T NewImm2 = Imm & ~NewImm1;
+
+ // Do not split if NewImm2 is not a valid bitmask immediate.
+ if (!AArch64_AM::isLogicalImmediate(NewImm2, RegSize))
+ return false;
+
+ Imm1Enc = AArch64_AM::encodeLogicalImmediate(NewImm1, RegSize);
+ Imm2Enc = AArch64_AM::encodeLogicalImmediate(NewImm2, RegSize);
+ return true;
+}
+
+template <typename T>
bool AArch64MIPeepholeOpt::trySplitLogicalImm(unsigned Opc, MachineInstr &MI,
SplitStrategy Strategy,
unsigned OtherOpc) {
- // Try below transformation.
+ // Try below transformations.
//
- // MOVi32imm + ANDS?Wrr ==> ANDWri + ANDS?Wri
- // MOVi64imm + ANDS?Xrr ==> ANDXri + ANDS?Xri
+ // MOVi32imm + (ANDS?|EOR|ORR)Wrr ==> (AND|EOR|ORR)Wri + (ANDS?|EOR|ORR)Wri
+ // MOVi64imm + (ANDS?|EOR|ORR)Xrr ==> (AND|EOR|ORR)Xri + (ANDS?|EOR|ORR)Xri
//
// The mov pseudo instruction could be expanded to multiple mov instructions
// later. Let's try to split the constant operand of mov instruction into two
- // bitmask immediates. It makes only two AND instructions instead of multiple
- // mov + and instructions.
+ // bitmask immediates based on the given split strategy. It makes only two
+ // logical instructions instead of multiple mov + logic instructions.
return splitTwoPartImm<T>(
MI,
@@ -224,6 +255,9 @@ bool AArch64MIPeepholeOpt::trySplitLogicalImm(unsigned Opc, MachineInstr &MI,
case SplitStrategy::Intersect:
SplitSucc = splitBitmaskImm(Imm, RegSize, Imm0, Imm1);
break;
+ case SplitStrategy::Disjoint:
+ SplitSucc = splitDisjointBitmaskImm(Imm, RegSize, Imm0, Imm1);
+ break;
}
if (SplitSucc)
return std::make_pair(Opc, !OtherOpc ? Opc : OtherOpc);
@@ -889,6 +923,22 @@ bool AArch64MIPeepholeOpt::runOnMachineFunction(MachineFunction &MF) {
Changed |= trySplitLogicalImm<uint64_t>(
AArch64::ANDXri, MI, SplitStrategy::Intersect, AArch64::ANDSXri);
break;
+ case AArch64::EORWrr:
+ Changed |= trySplitLogicalImm<uint32_t>(AArch64::EORWri, MI,
+ SplitStrategy::Disjoint);
+ break;
+ case AArch64::EORXrr:
+ Changed |= trySplitLogicalImm<uint64_t>(AArch64::EORXri, MI,
+ SplitStrategy::Disjoint);
+ break;
+ case AArch64::ORRWrr:
+ Changed |= trySplitLogicalImm<uint32_t>(AArch64::ORRWri, MI,
+ SplitStrategy::Disjoint);
+ break;
+ case AArch64::ORRXrr:
+ Changed |= trySplitLogicalImm<uint64_t>(AArch64::ORRXri, MI,
+ SplitStrategy::Disjoint);
+ break;
case AArch64::ORRWrs:
Changed |= visitORR(MI);
break;
diff --git a/llvm/lib/Target/AArch64/AArch64Processors.td b/llvm/lib/Target/AArch64/AArch64Processors.td
index adc984a..1bc1d98 100644
--- a/llvm/lib/Target/AArch64/AArch64Processors.td
+++ b/llvm/lib/Target/AArch64/AArch64Processors.td
@@ -22,7 +22,8 @@ def TuneA320 : SubtargetFeature<"a320", "ARMProcFamily", "CortexA320",
FeatureFuseAES,
FeatureFuseAdrpAdd,
FeaturePostRAScheduler,
- FeatureUseWzrToVecMove]>;
+ FeatureUseWzrToVecMove,
+ FeatureUseFixedOverScalableIfEqualCost]>;
def TuneA53 : SubtargetFeature<"a53", "ARMProcFamily", "CortexA53",
"Cortex-A53 ARM processors", [
@@ -45,7 +46,8 @@ def TuneA510 : SubtargetFeature<"a510", "ARMProcFamily", "CortexA510",
FeatureFuseAES,
FeatureFuseAdrpAdd,
FeaturePostRAScheduler,
- FeatureUseWzrToVecMove
+ FeatureUseWzrToVecMove,
+ FeatureUseFixedOverScalableIfEqualCost
]>;
def TuneA520 : SubtargetFeature<"a520", "ARMProcFamily", "CortexA520",
@@ -53,7 +55,8 @@ def TuneA520 : SubtargetFeature<"a520", "ARMProcFamily", "CortexA520",
FeatureFuseAES,
FeatureFuseAdrpAdd,
FeaturePostRAScheduler,
- FeatureUseWzrToVecMove]>;
+ FeatureUseWzrToVecMove,
+ FeatureUseFixedOverScalableIfEqualCost]>;
def TuneA520AE : SubtargetFeature<"a520ae", "ARMProcFamily", "CortexA520",
"Cortex-A520AE ARM processors", [
@@ -756,7 +759,6 @@ def ProcessorFeatures {
FeatureSB, FeaturePAuth, FeatureSSBS, FeatureSVE, FeatureSVE2,
FeatureComplxNum, FeatureCRC, FeatureDotProd,
FeatureFPARMv8,FeatureFullFP16, FeatureJS, FeatureLSE,
- FeatureUseFixedOverScalableIfEqualCost,
FeatureRAS, FeatureRCPC, FeatureRDM, FeatureFPAC];
list<SubtargetFeature> A520 = [HasV9_2aOps, FeaturePerfMon, FeatureAM,
FeatureMTE, FeatureETE, FeatureSVEBitPerm,
@@ -766,7 +768,6 @@ def ProcessorFeatures {
FeatureSVE, FeatureSVE2, FeatureBF16, FeatureComplxNum, FeatureCRC,
FeatureFPARMv8, FeatureFullFP16, FeatureMatMulInt8, FeatureJS,
FeatureNEON, FeatureLSE, FeatureRAS, FeatureRCPC, FeatureRDM,
- FeatureUseFixedOverScalableIfEqualCost,
FeatureDotProd, FeatureFPAC];
list<SubtargetFeature> A520AE = [HasV9_2aOps, FeaturePerfMon, FeatureAM,
FeatureMTE, FeatureETE, FeatureSVEBitPerm,
diff --git a/llvm/lib/Target/AVR/AVRISelLowering.cpp b/llvm/lib/Target/AVR/AVRISelLowering.cpp
index 3955f2a..25ad9ec 100644
--- a/llvm/lib/Target/AVR/AVRISelLowering.cpp
+++ b/llvm/lib/Target/AVR/AVRISelLowering.cpp
@@ -669,7 +669,7 @@ SDValue AVRTargetLowering::getAVRCmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
default: {
// Turn lhs < rhs with lhs constant into rhs >= lhs+1, this allows
// us to fold the constant into the cmp instruction.
- RHS = DAG.getConstant(C->getSExtValue() + 1, DL, VT);
+ RHS = DAG.getSignedConstant(C->getSExtValue() + 1, DL, VT);
CC = ISD::SETGE;
break;
}
@@ -713,7 +713,10 @@ SDValue AVRTargetLowering::getAVRCmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
// Turn lhs < rhs with lhs constant into rhs >= lhs+1, this allows us to
// fold the constant into the cmp instruction.
if (const ConstantSDNode *C = dyn_cast<ConstantSDNode>(RHS)) {
- RHS = DAG.getConstant(C->getSExtValue() + 1, DL, VT);
+ // An "icmp ugt i16 65535, %0" comparison should already have been converted
+ // to something else. Assert to make sure this assumption holds.
+ assert((!C->isAllOnes()) && "integer overflow in comparison transform");
+ RHS = DAG.getConstant(C->getZExtValue() + 1, DL, VT);
CC = ISD::SETUGE;
break;
}
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index ffd900c..5153d24 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -56,6 +56,8 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_umax:
case Intrinsic::dx_wave_reduce_usum:
+ case Intrinsic::dx_imad:
+ case Intrinsic::dx_umad:
return true;
default:
return false;
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 9003ace..d4f0cc9 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -4046,6 +4046,18 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
return true;
}
+ case Intrinsic::nvvm_prefetch_tensormap: {
+ auto &DL = I.getDataLayout();
+ Info.opc = ISD::INTRINSIC_VOID;
+ Info.memVT = getPointerTy(DL);
+ Info.ptrVal = I.getArgOperand(0);
+ Info.offset = 0;
+ Info.flags =
+ MachineMemOperand::MOLoad | MachineMemOperand::MODereferenceable;
+ Info.align.reset();
+ return true;
+ }
+
case Intrinsic::nvvm_ldu_global_i:
case Intrinsic::nvvm_ldu_global_f:
case Intrinsic::nvvm_ldu_global_p: {
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index d337192..d4a0ca7 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -39,6 +39,12 @@ def AS_match {
code global = [{
return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_GLOBAL);
}];
+ code const = [{
+ return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_CONST);
+ }];
+ code param = [{
+ return ChkMemSDNodeAddressSpace(N, llvm::ADDRESS_SPACE_PARAM);
+ }];
}
@@ -950,33 +956,47 @@ foreach dim = 3...5 in {
defm TMA_TENSOR_PF_TILE_GATHER4_2D : TMA_TENSOR_PREFETCH_INTR<5, "tile_gather4",
[hasTMACTAGroupSupport]>;
-//Prefetch and Prefetchu
-
-let Predicates = [hasPTX<80>, hasSM<90>] in {
- class PREFETCH_INTRS<string InstName> :
- BasicNVPTXInst<(outs), (ins ADDR:$addr),
- InstName,
- [(!cast<Intrinsic>(!strconcat("int_nvvm_",
- !subst(".", "_", InstName))) addr:$addr)]>;
+//Prefetchu and Prefetch
- def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1">;
- def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2">;
- def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1">;
- def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1">;
- def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2">;
- def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2">;
+defvar frag_pat = (int_nvvm_prefetch_tensormap node:$addr);
- def PREFETCH_GLOBAL_L2_EVICT_NORMAL : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "prefetch.global.L2::evict_normal",
- [(int_nvvm_prefetch_global_L2_evict_normal addr:$addr)]>;
+multiclass PREFETCH_TENSORMAP_PATFRAG<string suffix, code predicate> {
+ def !tolower(suffix) : PatFrag<!setdagop(frag_pat, ops), frag_pat, predicate>;
+}
- def PREFETCH_GLOBAL_L2_EVICT_LAST : BasicNVPTXInst<(outs), (ins ADDR:$addr),
- "prefetch.global.L2::evict_last",
- [(int_nvvm_prefetch_global_L2_evict_last addr:$addr)]>;
+defm prefetch_tensormap_ : PREFETCH_TENSORMAP_PATFRAG<"CONST", AS_match.const>;
+defm prefetch_tensormap_ : PREFETCH_TENSORMAP_PATFRAG<"GENERIC", AS_match.generic>;
+defm prefetch_tensormap_ : PREFETCH_TENSORMAP_PATFRAG<"PARAM", AS_match.param>;
- def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1">;
+multiclass PREFETCH_TENSORMAP_INST<string addrspace_name, PatFrag pattern_frag> {
+ def "" : BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ "prefetch" # addrspace_name # ".tensormap",
+ [(pattern_frag addr:$addr)]>,
+ Requires<[hasPTX<80>, hasSM<90>]>;
}
+defm PREFETCH_CONST_TENSORMAP : PREFETCH_TENSORMAP_INST<".const", prefetch_tensormap_const>;
+defm PREFETCH_GENERIC_TENSORMAP : PREFETCH_TENSORMAP_INST<"", prefetch_tensormap_generic>;
+defm PREFETCH_PARAM_TENSORMAP : PREFETCH_TENSORMAP_INST<".param", prefetch_tensormap_param>;
+
+class PREFETCH_INTRS<string InstName, Intrinsic Intr> :
+ BasicNVPTXInst<(outs), (ins ADDR:$addr),
+ InstName,
+ [(Intr addr:$addr)]>,
+ Requires<[hasPTX<80>, hasSM<90>]>;
+
+def PREFETCHU_L1 : PREFETCH_INTRS<"prefetchu.L1", int_nvvm_prefetchu_L1>;
+def PREFETCH_L1 : PREFETCH_INTRS<"prefetch.L1", int_nvvm_prefetch_L1>;
+def PREFETCH_L2 : PREFETCH_INTRS<"prefetch.L2", int_nvvm_prefetch_L2>;
+def PREFETCH_GLOBAL_L1 : PREFETCH_INTRS<"prefetch.global.L1", int_nvvm_prefetch_global_L1>;
+def PREFETCH_LOCAL_L1 : PREFETCH_INTRS<"prefetch.local.L1", int_nvvm_prefetch_local_L1>;
+def PREFETCH_GLOBAL_L2 : PREFETCH_INTRS<"prefetch.global.L2", int_nvvm_prefetch_global_L2>;
+def PREFETCH_LOCAL_L2 : PREFETCH_INTRS<"prefetch.local.L2", int_nvvm_prefetch_local_L2>;
+def PREFETCH_GLOBAL_L2_EVICT_NORMAL : PREFETCH_INTRS<"prefetch.global.L2::evict_normal",
+ int_nvvm_prefetch_global_L2_evict_normal>;
+def PREFETCH_GLOBAL_L2_EVICT_LAST : PREFETCH_INTRS<"prefetch.global.L2::evict_last",
+ int_nvvm_prefetch_global_L2_evict_last>;
+
//Applypriority intrinsics
class APPLYPRIORITY_L2_INTRS<string addrspace> :
BasicNVPTXInst<(outs), (ins ADDR:$addr, B64:$size),
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 3ae2d9d..f4f8961 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -564,7 +564,8 @@ bool NVPTXTTIImpl::collectFlatAddressOperands(SmallVectorImpl<int> &OpIndexes,
case Intrinsic::nvvm_isspacep_global:
case Intrinsic::nvvm_isspacep_local:
case Intrinsic::nvvm_isspacep_shared:
- case Intrinsic::nvvm_isspacep_shared_cluster: {
+ case Intrinsic::nvvm_isspacep_shared_cluster:
+ case Intrinsic::nvvm_prefetch_tensormap: {
OpIndexes.push_back(0);
return true;
}
@@ -587,6 +588,11 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II,
return ConstantInt::get(II->getType(), *R);
return nullptr;
}
+ case Intrinsic::nvvm_prefetch_tensormap: {
+ IRBuilder<> Builder(II);
+ return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_prefetch_tensormap,
+ NewV);
+ }
}
return nullptr;
}
diff --git a/llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td b/llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td
index bf23812..5541506 100644
--- a/llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td
+++ b/llvm/lib/Target/RISCV/RISCVSchedSpacemitX60.td
@@ -13,78 +13,113 @@
//
//===----------------------------------------------------------------------===//
-class SMX60IsWorstCaseMX<string mx, list<string> MxList> {
- string LLMUL = LargestLMUL<MxList>.r;
- bit c = !eq(mx, LLMUL);
-}
+//===----------------------------------------------------------------------===//
+// Helpers
+
+// Maps LMUL string to corresponding value from the Values array
+// LMUL values map to array indices as follows:
+// MF8 -> Values[0], MF4 -> Values[1], MF2 -> Values[2], M1 -> Values[3],
+// M2 -> Values[4], M4 -> Values[5], M8 -> Values[6]
+// Shorter lists are allowed, e.g., widening instructions don't work on M8
+class GetLMULValue<list<int> Values, string LMUL> {
+ defvar Index = !cond(
+ !eq(LMUL, "MF8"): 0,
+ !eq(LMUL, "MF4"): 1,
+ !eq(LMUL, "MF2"): 2,
+ !eq(LMUL, "M1"): 3,
+ !eq(LMUL, "M2"): 4,
+ !eq(LMUL, "M4"): 5,
+ !eq(LMUL, "M8"): 6,
+ );
-class SMX60IsWorstCaseMXSEW<string mx, int sew, list<string> MxList, bit isF = 0> {
- string LLMUL = LargestLMUL<MxList>.r;
- int SSEW = SmallestSEW<mx, isF>.r;
- bit c = !and(!eq(mx, LLMUL), !eq(sew, SSEW));
+ assert !lt(Index, !size(Values)),
+ "Missing LMUL value for '" # LMUL # "'. " #
+ "Expected at least " # !add(Index, 1) # " elements, but got " #
+ !size(Values) # ".";
+
+ int c = Values[Index];
}
-defvar SMX60VLEN = 256;
-defvar SMX60DLEN = !div(SMX60VLEN, 2);
+// Returns BaseValue for LMUL values before startLMUL, Value for startLMUL,
+// then doubles Value for each subsequent LMUL.
+// Example: ConstValueUntilLMULThenDoubleBase<"M1", 2, 4, mx> yields, per mx:
+// MF8->2, MF4->2, MF2->2, M1->4, M2->8, M4->16, M8->32
+// This is useful for modeling scheduling parameters that scale with LMUL.
+class ConstValueUntilLMULThenDoubleBase<string startLMUL, int BaseValue, int Value, string currentLMUL> {
+ assert !le(BaseValue, Value), "BaseValue must be less-equal to Value";
+ defvar startPos = GetLMULValue<[0, 1, 2, 3, 4, 5, 6], startLMUL>.c;
+ defvar currentPos = GetLMULValue<[0, 1, 2, 3, 4, 5, 6], currentLMUL>.c;
-class Get1248Latency<string mx> {
+ // Calculate the difference in positions
+ defvar posDiff = !sub(currentPos, startPos);
+
+ // Calculate Value * (2^posDiff)
int c = !cond(
- !eq(mx, "M2") : 2,
- !eq(mx, "M4") : 4,
- !eq(mx, "M8") : 8,
- true: 1
+ !eq(posDiff, 0) : Value,
+ !eq(posDiff, 1) : !mul(Value, 2),
+ !eq(posDiff, 2) : !mul(Value, 4),
+ !eq(posDiff, 3) : !mul(Value, 8),
+ !eq(posDiff, 4) : !mul(Value, 16),
+ !eq(posDiff, 5) : !mul(Value, 32),
+ !eq(posDiff, 6) : !mul(Value, 64),
+ true : BaseValue
);
}
-// Used for: logical opsz, shifts, sign ext, merge/move, FP sign/recip/convert, mask ops, slides
-class Get4816Latency<string mx> {
- int c = !cond(
- !eq(mx, "M4") : 8,
- !eq(mx, "M8") : 16,
- true: 4
- );
+// Same as ConstValueUntilLMULThenDoubleBase, but with BaseValue == Value.
+class ConstValueUntilLMULThenDouble<string startLMUL, int Value, string currentLMUL> {
+ int c = ConstValueUntilLMULThenDoubleBase<startLMUL, Value, Value, currentLMUL>.c;
+}
+
+// Returns MF8->1, MF4->1, MF2->2, M1->4, M2->8, M4->16, M8->32
+class ConstOneUntilMF4ThenDouble<string mx> {
+ int c = ConstValueUntilLMULThenDouble<"MF4", 1, mx>.c;
+}
+
+// Returns MF8->1, MF4->1, MF2->1, M1->2, M2->4, M4->8, M8->16
+class ConstOneUntilMF2ThenDouble<string mx> {
+ int c = ConstValueUntilLMULThenDouble<"MF2", 1, mx>.c;
+}
+
+// Returns MF8->1, MF4->1, MF2->1, M1->1, M2->2, M4->4, M8->8
+class ConstOneUntilM1ThenDouble<string mx> {
+ int c = ConstValueUntilLMULThenDouble<"M1", 1, mx>.c;
}
+//===----------------------------------------------------------------------===//
+// Latency helper classes
+
// Used for: arithmetic (add/sub/min/max), saturating/averaging, FP add/sub/min/max
-class Get458Latency<string mx> {
- int c = !cond(
- !eq(mx, "M4") : 5,
- !eq(mx, "M8") : 8,
- true: 4
- );
+class Get4458Latency<string mx> {
+ int c = GetLMULValue<[/*MF8=*/4, /*MF4=*/4, /*MF2=*/4, /*M1=*/4, /*M2=*/4, /*M4=*/5, /*M8=*/8], mx>.c;
}
-// Widening scaling pattern (4,4,4,4,5,8,8): plateaus at higher LMULs
-// Used for: widening operations
+// Used for: widening operations (no M8)
class Get4588Latency<string mx> {
- int c = !cond(
- !eq(mx, "M2") : 5,
- !eq(mx, "M4") : 8,
- !eq(mx, "M8") : 8, // M8 not supported for most widening, fallback
- true: 4
- );
+ int c = GetLMULValue<[/*MF8=*/4, /*MF4=*/4, /*MF2=*/4, /*M1=*/4, /*M2=*/5, /*M4=*/8], mx>.c;
}
// Used for: mask-producing comparisons, carry ops with mask, FP comparisons
class Get461018Latency<string mx> {
- int c = !cond(
- !eq(mx, "M2") : 6,
- !eq(mx, "M4") : 10,
- !eq(mx, "M8") : 18,
- true: 4
- );
+ int c = GetLMULValue<[/*MF8=*/4, /*MF4=*/4, /*MF2=*/4, /*M1=*/4, /*M2=*/6, /*M4=*/10, /*M8=*/18], mx>.c;
}
-// Used for: e64 multiply pattern, complex ops
-class Get781632Latency<string mx> {
- int c = !cond(
- !eq(mx, "M2") : 8,
- !eq(mx, "M4") : 16,
- !eq(mx, "M8") : 32,
- true: 7
- );
+//===----------------------------------------------------------------------===//
+
+class SMX60IsWorstCaseMX<string mx, list<string> MxList> {
+ string LLMUL = LargestLMUL<MxList>.r;
+ bit c = !eq(mx, LLMUL);
}
+class SMX60IsWorstCaseMXSEW<string mx, int sew, list<string> MxList, bit isF = 0> {
+ string LLMUL = LargestLMUL<MxList>.r;
+ int SSEW = SmallestSEW<mx, isF>.r;
+ bit c = !and(!eq(mx, LLMUL), !eq(sew, SSEW));
+}
+
+defvar SMX60VLEN = 256;
+defvar SMX60DLEN = !div(SMX60VLEN, 2);
+
def SpacemitX60Model : SchedMachineModel {
let IssueWidth = 2; // dual-issue
let MicroOpBufferSize = 0; // in-order
@@ -383,12 +418,13 @@ foreach LMul = [1, 2, 4, 8] in {
foreach mx = SchedMxList in {
defvar IsWorstCase = SMX60IsWorstCaseMX<mx, SchedMxList>.c;
- let Latency = Get458Latency<mx>.c, ReleaseAtCycles = [4] in {
+ let Latency = Get4458Latency<mx>.c, ReleaseAtCycles = [4] in {
defm "" : LMULWriteResMX<"WriteVIMinMaxV", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVIMinMaxX", [SMX60_VIEU], mx, IsWorstCase>;
}
- let Latency = Get4816Latency<mx>.c, ReleaseAtCycles = [4] in {
+ defvar VIALULat = ConstValueUntilLMULThenDouble<"M2", 4, mx>.c;
+ let Latency = VIALULat, ReleaseAtCycles = [4] in {
// Pattern of vadd, vsub, vrsub: 4/4/5/8
// Pattern of vand, vor, vxor: 4/4/8/16
// They are grouped together, so we used the worst case 4/4/8/16
@@ -425,7 +461,7 @@ foreach mx = SchedMxList in {
// Pattern of vmacc, vmadd, vmul, vmulh, etc.: e8/e16 = 4/4/5/8, e32 = 5,5,5,8,
// e64 = 7,8,16,32. We use the worst-case until we can split the SEW.
// TODO: change WriteVIMulV, etc to be defined with LMULSEWSchedWrites
- let Latency = Get781632Latency<mx>.c, ReleaseAtCycles = [7] in {
+ let Latency = ConstValueUntilLMULThenDoubleBase<"M2", 7, 8, mx>.c, ReleaseAtCycles = [7] in {
defm "" : LMULWriteResMX<"WriteVIMulV", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVIMulX", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVIMulAddV", [SMX60_VIEU], mx, IsWorstCase>;
@@ -461,15 +497,8 @@ foreach mx = SchedMxList in {
foreach sew = SchedSEWSet<mx>.val in {
defvar IsWorstCase = SMX60IsWorstCaseMXSEW<mx, sew, SchedMxList>.c;
- // Slightly reduced for fractional LMULs
- defvar Multiplier = !cond(
- !eq(mx, "MF8") : 12,
- !eq(mx, "MF4") : 12,
- !eq(mx, "MF2") : 12,
- true: 24
- );
-
- let Latency = !mul(Get1248Latency<mx>.c, Multiplier), ReleaseAtCycles = [12] in {
+ defvar VIDivLat = ConstValueUntilLMULThenDouble<"MF2", 12, mx>.c;
+ let Latency = VIDivLat, ReleaseAtCycles = [12] in {
defm "" : LMULSEWWriteResMXSEW<"WriteVIDivV", [SMX60_VIEU], mx, sew, IsWorstCase>;
defm "" : LMULSEWWriteResMXSEW<"WriteVIDivX", [SMX60_VIEU], mx, sew, IsWorstCase>;
}
@@ -480,14 +509,8 @@ foreach mx = SchedMxList in {
foreach mx = SchedMxListW in {
defvar IsWorstCase = SMX60IsWorstCaseMX<mx, SchedMxListW>.c;
- // Slightly increased for integer LMULs
- defvar Multiplier = !cond(
- !eq(mx, "M2") : 2,
- !eq(mx, "M4") : 2,
- true: 1
- );
-
- let Latency = !mul(Get4816Latency<mx>.c, Multiplier), ReleaseAtCycles = [4] in {
+ defvar VNarrowingLat = ConstValueUntilLMULThenDouble<"M1", 4, mx>.c;
+ let Latency = VNarrowingLat, ReleaseAtCycles = [4] in {
defm "" : LMULWriteResMX<"WriteVNShiftV", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVNShiftX", [SMX60_VIEU], mx, IsWorstCase>;
defm "" : LMULWriteResMX<"WriteVNShiftI", [SMX60_VIEU], mx, IsWorstCase>;