aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/DirectX
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/DirectX')
-rw-r--r--llvm/lib/Target/DirectX/DXContainerGlobals.cpp7
-rw-r--r--llvm/lib/Target/DirectX/DXIL.td9
-rw-r--r--llvm/lib/Target/DirectX/DXILDataScalarization.cpp68
-rw-r--r--llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp4
-rw-r--r--llvm/lib/Target/DirectX/DXILOpLowering.cpp2
-rw-r--r--llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp66
-rw-r--r--llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp8
7 files changed, 130 insertions, 34 deletions
diff --git a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
index eb4c884..677203d 100644
--- a/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
+++ b/llvm/lib/Target/DirectX/DXContainerGlobals.cpp
@@ -285,6 +285,13 @@ void DXContainerGlobals::addPipelineStateValidationInfo(
PSV.BaseData.NumThreadsX = MMI.EntryPropertyVec[0].NumThreadsX;
PSV.BaseData.NumThreadsY = MMI.EntryPropertyVec[0].NumThreadsY;
PSV.BaseData.NumThreadsZ = MMI.EntryPropertyVec[0].NumThreadsZ;
+ if (MMI.EntryPropertyVec[0].WaveSizeMin) {
+ PSV.BaseData.MinimumWaveLaneCount = MMI.EntryPropertyVec[0].WaveSizeMin;
+ PSV.BaseData.MaximumWaveLaneCount =
+ MMI.EntryPropertyVec[0].WaveSizeMax
+ ? MMI.EntryPropertyVec[0].WaveSizeMax
+ : MMI.EntryPropertyVec[0].WaveSizeMin;
+ }
break;
default:
break;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 7ae500a..67437f6 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1079,6 +1079,15 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
let attributes = [Attributes<DXIL1_0, []>];
}
+def LegacyF16ToF32 : DXILOp<131, legacyF16ToF32> {
+ let Doc = "returns the float16 stored in the low-half of the uint converted "
+ "to a float";
+ let intrinsics = [IntrinSelect<int_dx_legacyf16tof32>];
+ let arguments = [Int32Ty];
+ let result = FloatTy;
+ let stages = [Stages<DXIL1_0, [all_stages]>];
+}
+
def WaveAllBitCount : DXILOp<135, waveAllOp> {
let Doc = "returns the count of bits set to 1 across the wave";
let intrinsics = [IntrinSelect<int_dx_wave_active_countbits>];
diff --git a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
index d507d71..9f1616f 100644
--- a/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
+++ b/llvm/lib/Target/DirectX/DXILDataScalarization.cpp
@@ -304,40 +304,76 @@ bool DataScalarizerVisitor::visitGetElementPtrInst(GetElementPtrInst &GEPI) {
GEPOperator *GOp = cast<GEPOperator>(&GEPI);
Value *PtrOperand = GOp->getPointerOperand();
Type *NewGEPType = GOp->getSourceElementType();
- bool NeedsTransform = false;
// Unwrap GEP ConstantExprs to find the base operand and element type
- while (auto *CE = dyn_cast<ConstantExpr>(PtrOperand)) {
- if (auto *GEPCE = dyn_cast<GEPOperator>(CE)) {
- GOp = GEPCE;
- PtrOperand = GEPCE->getPointerOperand();
- NewGEPType = GEPCE->getSourceElementType();
- } else
- break;
+ while (auto *GEPCE = dyn_cast_or_null<GEPOperator>(
+ dyn_cast<ConstantExpr>(PtrOperand))) {
+ GOp = GEPCE;
+ PtrOperand = GEPCE->getPointerOperand();
+ NewGEPType = GEPCE->getSourceElementType();
}
+ Type *const OrigGEPType = NewGEPType;
+ Value *const OrigOperand = PtrOperand;
+
if (GlobalVariable *NewGlobal = lookupReplacementGlobal(PtrOperand)) {
NewGEPType = NewGlobal->getValueType();
PtrOperand = NewGlobal;
- NeedsTransform = true;
} else if (AllocaInst *Alloca = dyn_cast<AllocaInst>(PtrOperand)) {
Type *AllocatedType = Alloca->getAllocatedType();
if (isa<ArrayType>(AllocatedType) &&
- AllocatedType != GOp->getResultElementType()) {
+ AllocatedType != GOp->getResultElementType())
NewGEPType = AllocatedType;
- NeedsTransform = true;
+ } else
+ return false; // Only GEPs into an alloca or global variable are considered
+
+ // Defer changing i8 GEP types until dxil-flatten-arrays
+ if (OrigGEPType->isIntegerTy(8))
+ NewGEPType = OrigGEPType;
+
+ // If the original type is a "sub-type" of the new type, then ensure the gep
+ // correctly zero-indexes the extra dimensions to keep the offset calculation
+ // correct.
+ // Eg:
+ // i32, [4 x i32] and [8 x [4 x i32]] are sub-types of [8 x [4 x i32]], etc.
+ //
+ // So then:
+ // gep [4 x i32] %idx
+ // -> gep [8 x [4 x i32]], i32 0, i32 %idx
+ // gep i32 %idx
+ // -> gep [8 x [4 x i32]], i32 0, i32 0, i32 %idx
+ uint32_t MissingDims = 0;
+ Type *SubType = NewGEPType;
+
+ // The new type will be in its array version; so match accordingly.
+ Type *const GEPArrType = equivalentArrayTypeFromVector(OrigGEPType);
+
+ while (SubType != GEPArrType) {
+ MissingDims++;
+
+ ArrayType *ArrType = dyn_cast<ArrayType>(SubType);
+ if (!ArrType) {
+ assert(SubType == GEPArrType &&
+ "GEP uses an DXIL invalid sub-type of alloca/global variable");
+ break;
}
+
+ SubType = ArrType->getElementType();
}
+ bool NeedsTransform = OrigOperand != PtrOperand ||
+ OrigGEPType != NewGEPType || MissingDims != 0;
+
if (!NeedsTransform)
return false;
- // Keep scalar GEPs scalar; dxil-flatten-arrays will do flattening later
- if (!isa<ArrayType>(GOp->getSourceElementType()))
- NewGEPType = GOp->getSourceElementType();
-
IRBuilder<> Builder(&GEPI);
- SmallVector<Value *, MaxVecSize> Indices(GOp->indices());
+ SmallVector<Value *, MaxVecSize> Indices;
+
+ for (uint32_t I = 0; I < MissingDims; I++)
+ Indices.push_back(Builder.getInt32(0));
+ llvm::append_range(Indices, GOp->indices());
+
Value *NewGEP = Builder.CreateGEP(NewGEPType, PtrOperand, Indices,
GOp->getName(), GOp->getNoWrapFlags());
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index ebb7c26..e0d2dbd 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -197,6 +197,7 @@ static Value *expand16BitIsNormal(CallInst *Orig) {
static bool isIntrinsicExpansion(Function &F) {
switch (F.getIntrinsicID()) {
+ case Intrinsic::assume:
case Intrinsic::abs:
case Intrinsic::atan2:
case Intrinsic::exp:
@@ -988,6 +989,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
case Intrinsic::abs:
Result = expandAbs(Orig);
break;
+ case Intrinsic::assume:
+ Orig->eraseFromParent();
+ return true;
case Intrinsic::atan2:
Result = expandAtan2Intrinsic(Orig);
break;
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 8720460..e46a393 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -904,8 +904,6 @@ public:
case Intrinsic::dx_resource_casthandle:
// NOTE: llvm.dbg.value is supported as is in DXIL.
case Intrinsic::dbg_value:
- // NOTE: llvm.assume is supported as is in DXIL.
- case Intrinsic::assume:
case Intrinsic::not_intrinsic:
if (F.use_empty())
F.eraseFromParent();
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index cf8b833..e1a472f 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -82,6 +82,7 @@ enum class EntryPropsTag {
ASStateTag,
WaveSize,
EntryRootSig,
+ WaveRange = 23,
};
} // namespace
@@ -177,14 +178,15 @@ getTagValueAsMetadata(EntryPropsTag Tag, uint64_t Value, LLVMContext &Ctx) {
case EntryPropsTag::ASStateTag:
case EntryPropsTag::WaveSize:
case EntryPropsTag::EntryRootSig:
+ case EntryPropsTag::WaveRange:
llvm_unreachable("NYI: Unhandled entry property tag");
}
return MDVals;
}
-static MDTuple *
-getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
- const Triple::EnvironmentType ShaderProfile) {
+static MDTuple *getEntryPropAsMetadata(Module &M, const EntryProperties &EP,
+ uint64_t EntryShaderFlags,
+ const ModuleMetadataInfo &MMDI) {
SmallVector<Metadata *> MDVals;
LLVMContext &Ctx = EP.Entry->getContext();
if (EntryShaderFlags != 0)
@@ -195,12 +197,13 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
// FIXME: support more props.
// See https://github.com/llvm/llvm-project/issues/57948.
// Add shader kind for lib entries.
- if (ShaderProfile == Triple::EnvironmentType::Library &&
+ if (MMDI.ShaderProfile == Triple::EnvironmentType::Library &&
EP.ShaderStage != Triple::EnvironmentType::Library)
MDVals.append(getTagValueAsMetadata(EntryPropsTag::ShaderKind,
getShaderStage(EP.ShaderStage), Ctx));
if (EP.ShaderStage == Triple::EnvironmentType::Compute) {
+ // Handle mandatory "hlsl.numthreads"
MDVals.emplace_back(ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), static_cast<int>(EntryPropsTag::NumThreads))));
Metadata *NumThreadVals[] = {ConstantAsMetadata::get(ConstantInt::get(
@@ -210,8 +213,48 @@ getEntryPropAsMetadata(const EntryProperties &EP, uint64_t EntryShaderFlags,
ConstantAsMetadata::get(ConstantInt::get(
Type::getInt32Ty(Ctx), EP.NumThreadsZ))};
MDVals.emplace_back(MDNode::get(Ctx, NumThreadVals));
+
+ // Handle optional "hlsl.wavesize". The fields are optionally represented
+ // if they are non-zero.
+ if (EP.WaveSizeMin != 0) {
+ bool IsWaveRange = VersionTuple(6, 8) <= MMDI.ShaderModelVersion;
+ bool IsWaveSize =
+ !IsWaveRange && VersionTuple(6, 6) <= MMDI.ShaderModelVersion;
+
+ if (!IsWaveRange && !IsWaveSize) {
+ reportError(M, "Shader model 6.6 or greater is required to specify "
+ "the \"hlsl.wavesize\" function attribute");
+ return nullptr;
+ }
+
+ // A range is being specified if EP.WaveSizeMax != 0
+ if (EP.WaveSizeMax && !IsWaveRange) {
+ reportError(
+ M, "Shader model 6.8 or greater is required to specify "
+ "wave size range values of the \"hlsl.wavesize\" function "
+ "attribute");
+ return nullptr;
+ }
+
+ EntryPropsTag Tag =
+ IsWaveSize ? EntryPropsTag::WaveSize : EntryPropsTag::WaveRange;
+ MDVals.emplace_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), static_cast<int>(Tag))));
+
+ SmallVector<Metadata *> WaveSizeVals = {ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMin))};
+ if (IsWaveRange) {
+ WaveSizeVals.push_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizeMax)));
+ WaveSizeVals.push_back(ConstantAsMetadata::get(
+ ConstantInt::get(Type::getInt32Ty(Ctx), EP.WaveSizePref)));
+ }
+
+ MDVals.emplace_back(MDNode::get(Ctx, WaveSizeVals));
+ }
}
}
+
if (MDVals.empty())
return nullptr;
return MDNode::get(Ctx, MDVals);
@@ -236,12 +279,11 @@ static MDTuple *constructEntryMetadata(const Function *EntryFn,
return MDNode::get(Ctx, MDVals);
}
-static MDTuple *emitEntryMD(const EntryProperties &EP, MDTuple *Signatures,
- MDNode *MDResources,
+static MDTuple *emitEntryMD(Module &M, const EntryProperties &EP,
+ MDTuple *Signatures, MDNode *MDResources,
const uint64_t EntryShaderFlags,
- const Triple::EnvironmentType ShaderProfile) {
- MDTuple *Properties =
- getEntryPropAsMetadata(EP, EntryShaderFlags, ShaderProfile);
+ const ModuleMetadataInfo &MMDI) {
+ MDTuple *Properties = getEntryPropAsMetadata(M, EP, EntryShaderFlags, MMDI);
return constructEntryMetadata(EP.Entry, Signatures, MDResources, Properties,
EP.Entry->getContext());
}
@@ -523,10 +565,8 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
"'"));
}
-
- EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
- EntryShaderFlags,
- MMDI.ShaderProfile));
+ EntryFnMDNodes.emplace_back(emitEntryMD(
+ M, EntryProp, Signatures, ResourceMD, EntryShaderFlags, MMDI));
}
NamedMDNode *EntryPointsNamedMD =
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 60dfd96..6cacbf6 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -29,11 +29,12 @@ bool DirectXTTIImpl::isTargetIntrinsicWithOverloadTypeAtArg(Intrinsic::ID ID,
int OpdIdx) const {
switch (ID) {
case Intrinsic::dx_asdouble:
- case Intrinsic::dx_isinf:
- case Intrinsic::dx_isnan:
case Intrinsic::dx_firstbitlow:
- case Intrinsic::dx_firstbituhigh:
case Intrinsic::dx_firstbitshigh:
+ case Intrinsic::dx_firstbituhigh:
+ case Intrinsic::dx_isinf:
+ case Intrinsic::dx_isnan:
+ case Intrinsic::dx_legacyf16tof32:
return OpdIdx == 0;
default:
return OpdIdx == -1;
@@ -50,6 +51,7 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_frac:
case Intrinsic::dx_isinf:
case Intrinsic::dx_isnan:
+ case Intrinsic::dx_legacyf16tof32:
case Intrinsic::dx_rsqrt:
case Intrinsic::dx_saturate:
case Intrinsic::dx_splitdouble: