aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/DirectX
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/DirectX')
-rw-r--r--llvm/lib/Target/DirectX/DXIL.td10
-rw-r--r--llvm/lib/Target/DirectX/DXILOpLowering.cpp2
-rw-r--r--llvm/lib/Target/DirectX/DXILShaderFlags.cpp2
-rw-r--r--llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp255
-rw-r--r--llvm/lib/Target/DirectX/DXILTranslateMetadata.h17
-rw-r--r--llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp2
6 files changed, 209 insertions, 79 deletions
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index 44c4830..7ae500a 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1058,6 +1058,16 @@ def WaveActiveOp : DXILOp<119, waveActiveOp> {
IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Max>,
IntrinArgI8<SignedOpKind_Unsigned>
]>,
+ IntrinSelect<int_dx_wave_reduce_min,
+ [
+ IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Min>,
+ IntrinArgI8<SignedOpKind_Signed>
+ ]>,
+ IntrinSelect<int_dx_wave_reduce_umin,
+ [
+ IntrinArgIndex<0>, IntrinArgI8<WaveOpKind_Min>,
+ IntrinArgI8<SignedOpKind_Unsigned>
+ ]>,
];
let arguments = [OverloadTy, Int8Ty, Int8Ty];
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index e46a393..8720460 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -904,6 +904,8 @@ public:
case Intrinsic::dx_resource_casthandle:
// NOTE: llvm.dbg.value is supported as is in DXIL.
case Intrinsic::dbg_value:
+ // NOTE: llvm.assume is supported as is in DXIL.
+ case Intrinsic::assume:
case Intrinsic::not_intrinsic:
if (F.use_empty())
F.eraseFromParent();
diff --git a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
index e7e7f2c..ce6e812 100644
--- a/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
+++ b/llvm/lib/Target/DirectX/DXILShaderFlags.cpp
@@ -94,6 +94,8 @@ static bool checkWaveOps(Intrinsic::ID IID) {
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_wave_reduce_max:
case Intrinsic::dx_wave_reduce_umax:
+ case Intrinsic::dx_wave_reduce_min:
+ case Intrinsic::dx_wave_reduce_umin:
return true;
}
}
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
index 1e4797b..cf8b833 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.cpp
@@ -36,9 +36,10 @@ using namespace llvm;
using namespace llvm::dxil;
namespace {
-/// A simple Wrapper DiagnosticInfo that generates Module-level diagnostic
-/// for TranslateMetadata pass
-class DiagnosticInfoTranslateMD : public DiagnosticInfo {
+
+/// A simple wrapper of DiagnosticInfo that generates module-level diagnostic
+/// for the DXILValidateMetadata pass
+class DiagnosticInfoValidateMD : public DiagnosticInfo {
private:
const Twine &Msg;
const Module &Mod;
@@ -47,9 +48,9 @@ public:
/// \p M is the module for which the diagnostic is being emitted. \p Msg is
/// the message to show. Note that this class does not copy this message, so
/// this reference must be valid for the whole life time of the diagnostic.
- DiagnosticInfoTranslateMD(const Module &M,
- const Twine &Msg LLVM_LIFETIME_BOUND,
- DiagnosticSeverity Severity = DS_Error)
+ DiagnosticInfoValidateMD(const Module &M,
+ const Twine &Msg LLVM_LIFETIME_BOUND,
+ DiagnosticSeverity Severity = DS_Error)
: DiagnosticInfo(DK_Unsupported, Severity), Msg(Msg), Mod(M) {}
void print(DiagnosticPrinter &DP) const override {
@@ -57,6 +58,16 @@ public:
}
};
+static void reportError(Module &M, Twine Message,
+ DiagnosticSeverity Severity = DS_Error) {
+ M.getContext().diagnose(DiagnosticInfoValidateMD(M, Message, Severity));
+}
+
+static void reportLoopError(Module &M, Twine Message,
+ DiagnosticSeverity Severity = DS_Error) {
+ reportError(M, Twine("Invalid \"llvm.loop\" metadata: ") + Message, Severity);
+}
+
enum class EntryPropsTag {
ShaderFlags = 0,
GSState,
@@ -314,25 +325,122 @@ static void translateBranchMetadata(Module &M, Instruction *BBTerminatorInst) {
BBTerminatorInst->setMetadata("hlsl.controlflow.hint", nullptr);
}
-static std::array<unsigned, 6> getCompatibleInstructionMDs(llvm::Module &M) {
+// Determines if the metadata node will be compatible with DXIL's loop metadata
+// representation.
+//
+// Reports an error for compatible metadata that is ill-formed.
+static bool isLoopMDCompatible(Module &M, Metadata *MD) {
+ // DXIL only accepts the following loop hints:
+ std::array<StringLiteral, 3> ValidHintNames = {"llvm.loop.unroll.count",
+ "llvm.loop.unroll.disable",
+ "llvm.loop.unroll.full"};
+
+ MDNode *HintMD = dyn_cast<MDNode>(MD);
+ if (!HintMD || HintMD->getNumOperands() == 0)
+ return false;
+
+ auto *HintStr = dyn_cast<MDString>(HintMD->getOperand(0));
+ if (!HintStr)
+ return false;
+
+ if (!llvm::is_contained(ValidHintNames, HintStr->getString()))
+ return false;
+
+ auto ValidCountNode = [](MDNode *CountMD) -> bool {
+ if (CountMD->getNumOperands() == 2)
+ if (auto *Count = dyn_cast<ConstantAsMetadata>(CountMD->getOperand(1)))
+ if (isa<ConstantInt>(Count->getValue()))
+ return true;
+ return false;
+ };
+
+ if (HintStr->getString() == "llvm.loop.unroll.count") {
+ if (!ValidCountNode(HintMD)) {
+ reportLoopError(M, "\"llvm.loop.unroll.count\" must have 2 operands and "
+ "the second must be a constant integer");
+ return false;
+ }
+ } else if (HintMD->getNumOperands() != 1) {
+ reportLoopError(
+ M, "\"llvm.loop.unroll.disable\" and \"llvm.loop.unroll.full\" "
+ "must be provided as a single operand");
+ return false;
+ }
+
+ return true;
+}
+
+static void translateLoopMetadata(Module &M, Instruction *I, MDNode *BaseMD) {
+ // A distinct node has the self-referential form: !0 = !{ !0, ... }
+ auto IsDistinctNode = [](MDNode *Node) -> bool {
+ return Node && Node->getNumOperands() != 0 && Node == Node->getOperand(0);
+ };
+
+ // Set metadata to null to remove empty/ill-formed metadata from instruction
+ if (BaseMD->getNumOperands() == 0 || !IsDistinctNode(BaseMD))
+ return I->setMetadata("llvm.loop", nullptr);
+
+ // It is valid to have a chain of self-referential loop metadata nodes, as
+ // below. We will collapse these into just one when we reconstruct the
+ // metadata.
+ //
+ // Eg:
+ // !0 = !{!0, !1}
+ // !1 = !{!1, !2}
+ // !2 = !{!"llvm.loop.unroll.disable"}
+ //
+ // So, traverse down a potential self-referential chain
+ while (1 < BaseMD->getNumOperands() &&
+ IsDistinctNode(dyn_cast<MDNode>(BaseMD->getOperand(1))))
+ BaseMD = dyn_cast<MDNode>(BaseMD->getOperand(1));
+
+ // To reconstruct a distinct node we create a temporary node that we will
+ // then update to create a self-reference.
+ llvm::TempMDTuple TempNode = llvm::MDNode::getTemporary(M.getContext(), {});
+ SmallVector<Metadata *> CompatibleOperands = {TempNode.get()};
+
+ // Iterate and reconstruct the metadata nodes that contain any hints,
+ // stripping any unrecognized metadata.
+ ArrayRef<MDOperand> Operands = BaseMD->operands();
+ for (auto &Op : Operands.drop_front())
+ if (isLoopMDCompatible(M, Op.get()))
+ CompatibleOperands.push_back(Op.get());
+
+ if (2 < CompatibleOperands.size())
+ reportLoopError(M, "Provided conflicting hints");
+
+ MDNode *CompatibleLoopMD = MDNode::get(M.getContext(), CompatibleOperands);
+ TempNode->replaceAllUsesWith(CompatibleLoopMD);
+
+ I->setMetadata("llvm.loop", CompatibleLoopMD);
+}
+
+using InstructionMDList = std::array<unsigned, 7>;
+
+static InstructionMDList getCompatibleInstructionMDs(llvm::Module &M) {
return {
M.getMDKindID("dx.nonuniform"), M.getMDKindID("dx.controlflow.hints"),
M.getMDKindID("dx.precise"), llvm::LLVMContext::MD_range,
- llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias};
+ llvm::LLVMContext::MD_alias_scope, llvm::LLVMContext::MD_noalias,
+ M.getMDKindID("llvm.loop")};
}
static void translateInstructionMetadata(Module &M) {
// construct allowlist of valid metadata node kinds
- std::array<unsigned, 6> DXILCompatibleMDs = getCompatibleInstructionMDs(M);
+ InstructionMDList DXILCompatibleMDs = getCompatibleInstructionMDs(M);
+ unsigned char MDLoopKind = M.getContext().getMDKindID("llvm.loop");
for (Function &F : M) {
for (BasicBlock &BB : F) {
// This needs to be done first so that "hlsl.controlflow.hints" isn't
- // removed in the whitelist below
+ // removed in the allow-list below
if (auto *I = BB.getTerminator())
translateBranchMetadata(M, I);
for (auto &I : make_early_inc_range(BB)) {
+ if (isa<BranchInst>(I))
+ if (MDNode *LoopMD = I.getMetadata(MDLoopKind))
+ translateLoopMetadata(M, &I, LoopMD);
I.dropUnknownNonDebugMetadata(DXILCompatibleMDs);
}
}
@@ -364,6 +472,16 @@ static void cleanModuleFlags(Module &M) {
M.addModuleFlag(Flag.Behavior, Flag.Key->getString(), Flag.Val);
}
+using GlobalMDList = std::array<StringLiteral, 7>;
+
+// The following are compatible with DXIL but are not emitted by clang; they
+// can be added when applicable:
+// dx.typeAnnotations, dx.viewIDState, dx.dxrPayloadAnnotations
+static GlobalMDList CompatibleNamedModuleMDs = {
+ "llvm.ident", "llvm.module.flags", "dx.resources", "dx.valver",
+ "dx.shaderModel", "dx.version", "dx.entryPoints",
+};
+
static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
DXILResourceTypeMap &DRTM,
const ModuleShaderFlags &ShaderFlags,
@@ -389,31 +507,23 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
uint64_t CombinedMask = ShaderFlags.getCombinedFlags();
EntryFnMDNodes.emplace_back(
emitTopLevelLibraryNode(M, ResourceMD, CombinedMask));
- } else if (MMDI.EntryPropertyVec.size() > 1) {
- M.getContext().diagnose(DiagnosticInfoTranslateMD(
- M, "Non-library shader: One and only one entry expected"));
- }
+ } else if (1 < MMDI.EntryPropertyVec.size())
+ reportError(M, "Non-library shader: One and only one entry expected");
for (const EntryProperties &EntryProp : MMDI.EntryPropertyVec) {
- const ComputedShaderFlags &EntrySFMask =
- ShaderFlags.getFunctionFlags(EntryProp.Entry);
-
- // If ShaderProfile is Library, mask is already consolidated in the
- // top-level library node. Hence it is not emitted.
uint64_t EntryShaderFlags = 0;
if (MMDI.ShaderProfile != Triple::EnvironmentType::Library) {
- EntryShaderFlags = EntrySFMask;
- if (EntryProp.ShaderStage != MMDI.ShaderProfile) {
- M.getContext().diagnose(DiagnosticInfoTranslateMD(
- M,
- "Shader stage '" +
- Twine(getShortShaderStage(EntryProp.ShaderStage) +
- "' for entry '" + Twine(EntryProp.Entry->getName()) +
- "' different from specified target profile '" +
- Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
- "'"))));
- }
+ EntryShaderFlags = ShaderFlags.getFunctionFlags(EntryProp.Entry);
+ if (EntryProp.ShaderStage != MMDI.ShaderProfile)
+ reportError(
+ M, "Shader stage '" +
+ Twine(getShortShaderStage(EntryProp.ShaderStage)) +
+ "' for entry '" + Twine(EntryProp.Entry->getName()) +
+ "' different from specified target profile '" +
+ Twine(Triple::getEnvironmentTypeName(MMDI.ShaderProfile) +
+ "'"));
}
+
EntryFnMDNodes.emplace_back(emitEntryMD(EntryProp, Signatures, ResourceMD,
EntryShaderFlags,
MMDI.ShaderProfile));
@@ -426,19 +536,17 @@ static void translateGlobalMetadata(Module &M, DXILResourceMap &DRM,
cleanModuleFlags(M);
- // dx.rootsignatures will have been parsed from its metadata form as its
- // binary form as part of the RootSignatureAnalysisWrapper, so safely
- // remove it as it is not recognized in DXIL
- if (NamedMDNode *RootSignature = M.getNamedMetadata("dx.rootsignatures"))
- RootSignature->eraseFromParent();
+ // Finally, strip all module metadata that is not explicitly specified in the
+ // allow-list
+ SmallVector<NamedMDNode *> ToStrip;
- // llvm.errno.tbaa was recently added but is not supported in LLVM 3.7 and
- // causes all tests using the DXIL Validator to fail.
- //
- // This is a temporary fix and should be replaced with a allowlist once
- // we have determined all metadata that the DXIL Validator allows
- if (NamedMDNode *ErrNo = M.getNamedMetadata("llvm.errno.tbaa"))
- ErrNo->eraseFromParent();
+ for (NamedMDNode &NamedMD : M.named_metadata())
+ if (!NamedMD.getName().starts_with("llvm.dbg.") &&
+ !llvm::is_contained(CompatibleNamedModuleMDs, NamedMD.getName()))
+ ToStrip.push_back(&NamedMD);
+
+ for (NamedMDNode *NamedMD : ToStrip)
+ NamedMD->eraseFromParent();
}
PreservedAnalyses DXILTranslateMetadata::run(Module &M,
@@ -454,45 +562,34 @@ PreservedAnalyses DXILTranslateMetadata::run(Module &M,
return PreservedAnalyses::all();
}
-namespace {
-class DXILTranslateMetadataLegacy : public ModulePass {
-public:
- static char ID; // Pass identification, replacement for typeid
- explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}
-
- StringRef getPassName() const override { return "DXIL Translate Metadata"; }
-
- void getAnalysisUsage(AnalysisUsage &AU) const override {
- AU.addRequired<DXILResourceTypeWrapperPass>();
- AU.addRequired<DXILResourceWrapperPass>();
- AU.addRequired<ShaderFlagsAnalysisWrapper>();
- AU.addRequired<DXILMetadataAnalysisWrapperPass>();
- AU.addRequired<RootSignatureAnalysisWrapper>();
-
- AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
- AU.addPreserved<DXILResourceBindingWrapperPass>();
- AU.addPreserved<DXILResourceWrapperPass>();
- AU.addPreserved<RootSignatureAnalysisWrapper>();
- AU.addPreserved<ShaderFlagsAnalysisWrapper>();
- }
+void DXILTranslateMetadataLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
+ AU.addRequired<DXILResourceTypeWrapperPass>();
+ AU.addRequired<DXILResourceWrapperPass>();
+ AU.addRequired<ShaderFlagsAnalysisWrapper>();
+ AU.addRequired<DXILMetadataAnalysisWrapperPass>();
+ AU.addRequired<RootSignatureAnalysisWrapper>();
+
+ AU.addPreserved<DXILMetadataAnalysisWrapperPass>();
+ AU.addPreserved<DXILResourceBindingWrapperPass>();
+ AU.addPreserved<DXILResourceWrapperPass>();
+ AU.addPreserved<RootSignatureAnalysisWrapper>();
+ AU.addPreserved<ShaderFlagsAnalysisWrapper>();
+}
- bool runOnModule(Module &M) override {
- DXILResourceMap &DRM =
- getAnalysis<DXILResourceWrapperPass>().getResourceMap();
- DXILResourceTypeMap &DRTM =
- getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
- const ModuleShaderFlags &ShaderFlags =
- getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
- dxil::ModuleMetadataInfo MMDI =
- getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
-
- translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
- translateInstructionMetadata(M);
- return true;
- }
-};
+bool DXILTranslateMetadataLegacy::runOnModule(Module &M) {
+ DXILResourceMap &DRM =
+ getAnalysis<DXILResourceWrapperPass>().getResourceMap();
+ DXILResourceTypeMap &DRTM =
+ getAnalysis<DXILResourceTypeWrapperPass>().getResourceTypeMap();
+ const ModuleShaderFlags &ShaderFlags =
+ getAnalysis<ShaderFlagsAnalysisWrapper>().getShaderFlags();
+ dxil::ModuleMetadataInfo MMDI =
+ getAnalysis<DXILMetadataAnalysisWrapperPass>().getModuleMetadata();
-} // namespace
+ translateGlobalMetadata(M, DRM, DRTM, ShaderFlags, MMDI);
+ translateInstructionMetadata(M);
+ return true;
+}
char DXILTranslateMetadataLegacy::ID = 0;
diff --git a/llvm/lib/Target/DirectX/DXILTranslateMetadata.h b/llvm/lib/Target/DirectX/DXILTranslateMetadata.h
index 4c1ffac..cfb8aaa 100644
--- a/llvm/lib/Target/DirectX/DXILTranslateMetadata.h
+++ b/llvm/lib/Target/DirectX/DXILTranslateMetadata.h
@@ -10,6 +10,7 @@
#define LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H
#include "llvm/IR/PassManager.h"
+#include "llvm/Pass.h"
namespace llvm {
@@ -20,6 +21,22 @@ public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &);
};
+/// Wrapper pass for the legacy pass manager.
+///
+/// This is required because the passes that will depend on this are codegen
+/// passes which run through the legacy pass manager.
+class DXILTranslateMetadataLegacy : public ModulePass {
+public:
+ static char ID; // Pass identification, replacement for typeid
+ explicit DXILTranslateMetadataLegacy() : ModulePass(ID) {}
+
+ StringRef getPassName() const override { return "DXIL Translate Metadata"; }
+
+ void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+ bool runOnModule(Module &M) override;
+};
+
} // namespace llvm
#endif // LLVM_TARGET_DIRECTX_DXILTRANSLATEMETADATA_H
diff --git a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
index 68fd3e0..60dfd96 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetTransformInfo.cpp
@@ -55,8 +55,10 @@ bool DirectXTTIImpl::isTargetIntrinsicTriviallyScalarizable(
case Intrinsic::dx_splitdouble:
case Intrinsic::dx_wave_readlane:
case Intrinsic::dx_wave_reduce_max:
+ case Intrinsic::dx_wave_reduce_min:
case Intrinsic::dx_wave_reduce_sum:
case Intrinsic::dx_wave_reduce_umax:
+ case Intrinsic::dx_wave_reduce_umin:
case Intrinsic::dx_wave_reduce_usum:
case Intrinsic::dx_imad:
case Intrinsic::dx_umad: