aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Target/DirectX/DXILRootSignature.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Target/DirectX/DXILRootSignature.cpp')
-rw-r--r--llvm/lib/Target/DirectX/DXILRootSignature.cpp471
1 files changed, 8 insertions, 463 deletions
diff --git a/llvm/lib/Target/DirectX/DXILRootSignature.cpp b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
index dfc8162..ebdfcaa 100644
--- a/llvm/lib/Target/DirectX/DXILRootSignature.cpp
+++ b/llvm/lib/Target/DirectX/DXILRootSignature.cpp
@@ -16,6 +16,7 @@
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DXILMetadataAnalysis.h"
#include "llvm/BinaryFormat/DXContainer.h"
+#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DiagnosticInfo.h"
@@ -29,25 +30,10 @@
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
-#include <optional>
-#include <utility>
using namespace llvm;
using namespace llvm::dxil;
-static bool reportError(LLVMContext *Ctx, Twine Message,
- DiagnosticSeverity Severity = DS_Error) {
- Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
- return true;
-}
-
-static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
- uint32_t Value) {
- Ctx->diagnose(DiagnosticInfoGeneric(
- "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
- return true;
-}
-
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
unsigned int OpId) {
if (auto *CI =
@@ -56,453 +42,10 @@ static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
return std::nullopt;
}
-static std::optional<float> extractMdFloatValue(MDNode *Node,
- unsigned int OpId) {
- if (auto *CI = mdconst::dyn_extract<ConstantFP>(Node->getOperand(OpId).get()))
- return CI->getValueAPF().convertToFloat();
- return std::nullopt;
-}
-
-static std::optional<StringRef> extractMdStringValue(MDNode *Node,
- unsigned int OpId) {
- MDString *NodeText = dyn_cast<MDString>(Node->getOperand(OpId));
- if (NodeText == nullptr)
- return std::nullopt;
- return NodeText->getString();
-}
-
-static bool parseRootFlags(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootFlagNode) {
-
- if (RootFlagNode->getNumOperands() != 2)
- return reportError(Ctx, "Invalid format for RootFlag Element");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
- RSD.Flags = *Val;
- else
- return reportError(Ctx, "Invalid value for RootFlag");
-
- return false;
-}
-
-static bool parseRootConstants(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootConstantNode) {
-
- if (RootConstantNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for RootConstants Element");
-
- dxbc::RTS0::v1::RootParameterHeader Header;
- // The parameter offset doesn't matter here - we recalculate it during
- // serialization Header.ParameterOffset = 0;
- Header.ParameterType =
- llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
-
- dxbc::RTS0::v1::RootConstants Constants;
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
- Constants.ShaderRegister = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderRegister");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
- Constants.RegisterSpace = *Val;
- else
- return reportError(Ctx, "Invalid value for RegisterSpace");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
- Constants.Num32BitValues = *Val;
- else
- return reportError(Ctx, "Invalid value for Num32BitValues");
-
- RSD.ParametersContainer.addParameter(Header, Constants);
-
- return false;
-}
-
-static bool parseRootDescriptors(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootDescriptorNode,
- RootSignatureElementKind ElementKind) {
- assert(ElementKind == RootSignatureElementKind::SRV ||
- ElementKind == RootSignatureElementKind::UAV ||
- ElementKind == RootSignatureElementKind::CBV &&
- "parseRootDescriptors should only be called with RootDescriptor "
- "element kind.");
- if (RootDescriptorNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for Root Descriptor Element");
-
- dxbc::RTS0::v1::RootParameterHeader Header;
- switch (ElementKind) {
- case RootSignatureElementKind::SRV:
- Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
- break;
- case RootSignatureElementKind::UAV:
- Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
- break;
- case RootSignatureElementKind::CBV:
- Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
- break;
- default:
- llvm_unreachable("invalid Root Descriptor kind");
- break;
- }
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
-
- dxbc::RTS0::v2::RootDescriptor Descriptor;
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
- Descriptor.ShaderRegister = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderRegister");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
- Descriptor.RegisterSpace = *Val;
- else
- return reportError(Ctx, "Invalid value for RegisterSpace");
-
- if (RSD.Version == 1) {
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
- }
- assert(RSD.Version > 1);
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
- Descriptor.Flags = *Val;
- else
- return reportError(Ctx, "Invalid value for Root Descriptor Flags");
-
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
-}
-
-static bool parseDescriptorRange(LLVMContext *Ctx,
- mcdxbc::DescriptorTable &Table,
- MDNode *RangeDescriptorNode) {
-
- if (RangeDescriptorNode->getNumOperands() != 6)
- return reportError(Ctx, "Invalid format for Descriptor Range");
-
- dxbc::RTS0::v2::DescriptorRange Range;
-
- std::optional<StringRef> ElementText =
- extractMdStringValue(RangeDescriptorNode, 0);
-
- if (!ElementText.has_value())
- return reportError(Ctx, "Descriptor Range, first element is not a string.");
-
- Range.RangeType =
- StringSwitch<uint32_t>(*ElementText)
- .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
- .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
- .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
- .Case("Sampler",
- llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
- .Default(~0U);
-
- if (Range.RangeType == ~0U)
- return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
- Range.NumDescriptors = *Val;
- else
- return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
- Range.BaseShaderRegister = *Val;
- else
- return reportError(Ctx, "Invalid value for BaseShaderRegister");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
- Range.RegisterSpace = *Val;
- else
- return reportError(Ctx, "Invalid value for RegisterSpace");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
- Range.OffsetInDescriptorsFromTableStart = *Val;
- else
- return reportError(Ctx,
- "Invalid value for OffsetInDescriptorsFromTableStart");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
- Range.Flags = *Val;
- else
- return reportError(Ctx, "Invalid value for Descriptor Range Flags");
-
- Table.Ranges.push_back(Range);
- return false;
-}
-
-static bool parseDescriptorTable(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *DescriptorTableNode) {
- const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
- if (NumOperands < 2)
- return reportError(Ctx, "Invalid format for Descriptor Table");
-
- dxbc::RTS0::v1::RootParameterHeader Header;
- if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
-
- mcdxbc::DescriptorTable Table;
- Header.ParameterType =
- llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
-
- for (unsigned int I = 2; I < NumOperands; I++) {
- MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
- if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
-
- if (parseDescriptorRange(Ctx, Table, Element))
- return true;
- }
-
- RSD.ParametersContainer.addParameter(Header, Table);
- return false;
-}
-
-static bool parseStaticSampler(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *StaticSamplerNode) {
- if (StaticSamplerNode->getNumOperands() != 14)
- return reportError(Ctx, "Invalid format for Static Sampler");
-
- dxbc::RTS0::v1::StaticSampler Sampler;
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
- Sampler.Filter = *Val;
- else
- return reportError(Ctx, "Invalid value for Filter");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
- Sampler.AddressU = *Val;
- else
- return reportError(Ctx, "Invalid value for AddressU");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
- Sampler.AddressV = *Val;
- else
- return reportError(Ctx, "Invalid value for AddressV");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
- Sampler.AddressW = *Val;
- else
- return reportError(Ctx, "Invalid value for AddressW");
-
- if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
- Sampler.MipLODBias = *Val;
- else
- return reportError(Ctx, "Invalid value for MipLODBias");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
- Sampler.MaxAnisotropy = *Val;
- else
- return reportError(Ctx, "Invalid value for MaxAnisotropy");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
- Sampler.ComparisonFunc = *Val;
- else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
- Sampler.BorderColor = *Val;
- else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
-
- if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
- Sampler.MinLOD = *Val;
- else
- return reportError(Ctx, "Invalid value for MinLOD");
-
- if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
- Sampler.MaxLOD = *Val;
- else
- return reportError(Ctx, "Invalid value for MaxLOD");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
- Sampler.ShaderRegister = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderRegister");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
- Sampler.RegisterSpace = *Val;
- else
- return reportError(Ctx, "Invalid value for RegisterSpace");
-
- if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
- Sampler.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
-
- RSD.StaticSamplers.push_back(Sampler);
- return false;
-}
-
-static bool parseRootSignatureElement(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *Element) {
- std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
- if (!ElementText.has_value())
- return reportError(Ctx, "Invalid format for Root Element");
-
- RootSignatureElementKind ElementKind =
- StringSwitch<RootSignatureElementKind>(*ElementText)
- .Case("RootFlags", RootSignatureElementKind::RootFlags)
- .Case("RootConstants", RootSignatureElementKind::RootConstants)
- .Case("RootCBV", RootSignatureElementKind::CBV)
- .Case("RootSRV", RootSignatureElementKind::SRV)
- .Case("RootUAV", RootSignatureElementKind::UAV)
- .Case("DescriptorTable", RootSignatureElementKind::DescriptorTable)
- .Case("StaticSampler", RootSignatureElementKind::StaticSamplers)
- .Default(RootSignatureElementKind::Error);
-
- switch (ElementKind) {
-
- case RootSignatureElementKind::RootFlags:
- return parseRootFlags(Ctx, RSD, Element);
- case RootSignatureElementKind::RootConstants:
- return parseRootConstants(Ctx, RSD, Element);
- case RootSignatureElementKind::CBV:
- case RootSignatureElementKind::SRV:
- case RootSignatureElementKind::UAV:
- return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
- case RootSignatureElementKind::DescriptorTable:
- return parseDescriptorTable(Ctx, RSD, Element);
- case RootSignatureElementKind::StaticSamplers:
- return parseStaticSampler(Ctx, RSD, Element);
- case RootSignatureElementKind::Error:
- return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
- }
-
- llvm_unreachable("Unhandled RootSignatureElementKind enum.");
-}
-
-static bool parse(LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *Node) {
- bool HasError = false;
-
- // Loop through the Root Elements of the root signature.
- for (const auto &Operand : Node->operands()) {
- MDNode *Element = dyn_cast<MDNode>(Operand);
- if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
-
- HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element);
- }
-
- return HasError;
-}
-
-static bool validate(LLVMContext *Ctx, const mcdxbc::RootSignatureDesc &RSD) {
-
- if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
- return reportValueError(Ctx, "Version", RSD.Version);
- }
-
- if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
- return reportValueError(Ctx, "RootFlags", RSD.Flags);
- }
-
- for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
- if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Info.Header.ShaderVisibility);
-
- assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
- "Invalid value for ParameterType");
-
- switch (Info.Header.ParameterType) {
-
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV): {
- const dxbc::RTS0::v2::RootDescriptor &Descriptor =
- RSD.ParametersContainer.getRootDescriptor(Info.Location);
- if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister",
- Descriptor.ShaderRegister);
-
- if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
-
- if (RSD.Version > 1) {
- if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
- Descriptor.Flags))
- return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
- }
- break;
- }
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
- const mcdxbc::DescriptorTable &Table =
- RSD.ParametersContainer.getDescriptorTable(Info.Location);
- for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
- if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
- return reportValueError(Ctx, "RangeType", Range.RangeType);
-
- if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
-
- if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
- return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
-
- if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
- RSD.Version, Range.RangeType, Range.Flags))
- return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
- }
- break;
- }
- }
- }
-
- for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
- if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
- return reportValueError(Ctx, "Filter", Sampler.Filter);
-
- if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
- return reportValueError(Ctx, "AddressU", Sampler.AddressU);
-
- if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
- return reportValueError(Ctx, "AddressV", Sampler.AddressV);
-
- if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
- return reportValueError(Ctx, "AddressW", Sampler.AddressW);
-
- if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
- return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
-
- if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
- return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
-
- if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
- return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
-
- if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
- return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
-
- if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
- return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
-
- if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
- return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
-
- if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
-
- if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
-
- if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Sampler.ShaderVisibility);
- }
-
- return false;
+static bool reportError(LLVMContext *Ctx, Twine Message,
+ DiagnosticSeverity Severity = DS_Error) {
+ Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
+ return true;
}
static SmallDenseMap<const Function *, mcdxbc::RootSignatureDesc>
@@ -584,7 +127,9 @@ analyzeModule(Module &M) {
// static sampler offset is calculated when writting dxcontainer.
RSD.StaticSamplersOffset = 0u;
- if (parse(Ctx, RSD, RootElementListNode) || validate(Ctx, RSD)) {
+ hlsl::rootsig::MetadataParser MDParser(RootElementListNode);
+
+ if (MDParser.ParseRootSignature(Ctx, RSD)) {
return RSDMap;
}