aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Frontend/HLSL
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Frontend/HLSL')
-rw-r--r--llvm/lib/Frontend/HLSL/CMakeLists.txt1
-rw-r--r--llvm/lib/Frontend/HLSL/HLSLBinding.cpp142
-rw-r--r--llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp55
-rw-r--r--llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp546
-rw-r--r--llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp136
5 files changed, 451 insertions, 429 deletions
diff --git a/llvm/lib/Frontend/HLSL/CMakeLists.txt b/llvm/lib/Frontend/HLSL/CMakeLists.txt
index 5343469..3d22577 100644
--- a/llvm/lib/Frontend/HLSL/CMakeLists.txt
+++ b/llvm/lib/Frontend/HLSL/CMakeLists.txt
@@ -1,5 +1,6 @@
add_llvm_component_library(LLVMFrontendHLSL
CBuffer.cpp
+ HLSLBinding.cpp
HLSLResource.cpp
HLSLRootSignature.cpp
RootSignatureMetadata.cpp
diff --git a/llvm/lib/Frontend/HLSL/HLSLBinding.cpp b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp
new file mode 100644
index 0000000..401402f
--- /dev/null
+++ b/llvm/lib/Frontend/HLSL/HLSLBinding.cpp
@@ -0,0 +1,142 @@
+//===- HLSLBinding.cpp - Representation for resource bindings in HLSL -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Frontend/HLSL/HLSLBinding.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace llvm;
+using namespace hlsl;
+
+std::optional<uint32_t>
+BindingInfo::findAvailableBinding(dxil::ResourceClass RC, uint32_t Space,
+ int32_t Size) {
+ BindingSpaces &BS = getBindingSpaces(RC);
+ RegisterSpace &RS = BS.getOrInsertSpace(Space);
+ return RS.findAvailableBinding(Size);
+}
+
+BindingInfo::RegisterSpace &
+BindingInfo::BindingSpaces::getOrInsertSpace(uint32_t Space) {
+ for (auto It = Spaces.begin(), End = Spaces.end(); It != End; ++It) {
+ if (It->Space == Space)
+ return *It;
+ if (It->Space < Space)
+ continue;
+ return *Spaces.insert(It, Space);
+ }
+ return Spaces.emplace_back(Space);
+}
+
+std::optional<uint32_t>
+BindingInfo::RegisterSpace::findAvailableBinding(int32_t Size) {
+ assert((Size == -1 || Size > 0) && "invalid size");
+
+ if (FreeRanges.empty())
+ return std::nullopt;
+
+ // unbounded array
+ if (Size == -1) {
+ BindingRange &Last = FreeRanges.back();
+ if (Last.UpperBound != ~0u)
+ // this space is already occupied by an unbounded array
+ return std::nullopt;
+ uint32_t RegSlot = Last.LowerBound;
+ FreeRanges.pop_back();
+ return RegSlot;
+ }
+
+ // single resource or fixed-size array
+ for (BindingRange &R : FreeRanges) {
+ // compare the size as uint64_t to prevent overflow for range (0, ~0u)
+ if ((uint64_t)R.UpperBound - R.LowerBound + 1 < (uint64_t)Size)
+ continue;
+ uint32_t RegSlot = R.LowerBound;
+ // This might create a range where (LowerBound == UpperBound + 1). When
+ // that happens, the next time this function is called the range will
+ // skipped over by the check above (at this point Size is always > 0).
+ R.LowerBound += Size;
+ return RegSlot;
+ }
+
+ return std::nullopt;
+}
+
+BindingInfo BindingInfoBuilder::calculateBindingInfo(
+ llvm::function_ref<void(const BindingInfoBuilder &Builder,
+ const Binding &Overlapping)>
+ ReportOverlap) {
+ // sort all the collected bindings
+ llvm::stable_sort(Bindings);
+
+ // remove duplicates
+ Binding *NewEnd = llvm::unique(Bindings);
+ if (NewEnd != Bindings.end())
+ Bindings.erase(NewEnd, Bindings.end());
+
+ BindingInfo Info;
+
+ // Go over the sorted bindings and build up lists of free register ranges
+ // for each binding type and used spaces. Bindings are sorted by resource
+ // class, space, and lower bound register slot.
+ BindingInfo::BindingSpaces *BS =
+ &Info.getBindingSpaces(dxil::ResourceClass::SRV);
+ for (const Binding &B : Bindings) {
+ if (BS->RC != B.RC)
+ // move to the next resource class spaces
+ BS = &Info.getBindingSpaces(B.RC);
+
+ BindingInfo::RegisterSpace *S = BS->Spaces.empty()
+ ? &BS->Spaces.emplace_back(B.Space)
+ : &BS->Spaces.back();
+ assert(S->Space <= B.Space && "bindings not sorted correctly?");
+ if (B.Space != S->Space)
+ // add new space
+ S = &BS->Spaces.emplace_back(B.Space);
+
+ // The space is full - there are no free slots left, or the rest of the
+ // slots are taken by an unbounded array. Report the overlapping to the
+ // caller.
+ if (S->FreeRanges.empty() || S->FreeRanges.back().UpperBound < ~0u) {
+ ReportOverlap(*this, B);
+ continue;
+ }
+ // adjust the last free range lower bound, split it in two, or remove it
+ BindingInfo::BindingRange &LastFreeRange = S->FreeRanges.back();
+ if (LastFreeRange.LowerBound == B.LowerBound) {
+ if (B.UpperBound < ~0u)
+ LastFreeRange.LowerBound = B.UpperBound + 1;
+ else
+ S->FreeRanges.pop_back();
+ } else if (LastFreeRange.LowerBound < B.LowerBound) {
+ LastFreeRange.UpperBound = B.LowerBound - 1;
+ if (B.UpperBound < ~0u)
+ S->FreeRanges.emplace_back(B.UpperBound + 1, ~0u);
+ } else {
+ // We don't have room here. Report the overlapping binding to the caller
+ // and mark any extra space this binding would use as unavailable.
+ ReportOverlap(*this, B);
+ if (B.UpperBound < ~0u)
+ LastFreeRange.LowerBound =
+ std::max(LastFreeRange.LowerBound, B.UpperBound + 1);
+ else
+ S->FreeRanges.pop_back();
+ }
+ }
+
+ return Info;
+}
+
+const Binding &
+BindingInfoBuilder::findOverlapping(const Binding &ReportedBinding) const {
+ for (const Binding &Other : Bindings)
+ if (ReportedBinding.LowerBound <= Other.UpperBound &&
+ Other.LowerBound <= ReportedBinding.UpperBound)
+ return Other;
+
+ llvm_unreachable("Searching for overlap for binding that does not overlap");
+}
diff --git a/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp
index 78c20a6..92c62b8 100644
--- a/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp
+++ b/llvm/lib/Frontend/HLSL/HLSLRootSignature.cpp
@@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//
#include "llvm/Frontend/HLSL/HLSLRootSignature.h"
+#include "llvm/Support/DXILABI.h"
#include "llvm/Support/ScopedPrinter.h"
namespace llvm {
@@ -18,24 +19,6 @@ namespace hlsl {
namespace rootsig {
template <typename T>
-static std::optional<StringRef> getEnumName(const T Value,
- ArrayRef<EnumEntry<T>> Enums) {
- for (const auto &EnumItem : Enums)
- if (EnumItem.Value == Value)
- return EnumItem.Name;
- return std::nullopt;
-}
-
-template <typename T>
-static raw_ostream &printEnum(raw_ostream &OS, const T Value,
- ArrayRef<EnumEntry<T>> Enums) {
- auto MaybeName = getEnumName(Value, Enums);
- if (MaybeName)
- OS << *MaybeName;
- return OS;
-}
-
-template <typename T>
static raw_ostream &printFlags(raw_ostream &OS, const T Value,
ArrayRef<EnumEntry<T>> Flags) {
bool FlagSet = false;
@@ -46,9 +29,9 @@ static raw_ostream &printFlags(raw_ostream &OS, const T Value,
if (FlagSet)
OS << " | ";
- auto MaybeFlag = getEnumName(T(Bit), Flags);
- if (MaybeFlag)
- OS << *MaybeFlag;
+ StringRef MaybeFlag = enumToStringRef(T(Bit), Flags);
+ if (!MaybeFlag.empty())
+ OS << MaybeFlag;
else
OS << "invalid: " << Bit;
@@ -70,58 +53,49 @@ static const EnumEntry<RegisterType> RegisterNames[] = {
};
static raw_ostream &operator<<(raw_ostream &OS, const Register &Reg) {
- printEnum(OS, Reg.ViewType, ArrayRef(RegisterNames));
- OS << Reg.Number;
+ OS << enumToStringRef(Reg.ViewType, ArrayRef(RegisterNames)) << Reg.Number;
return OS;
}
static raw_ostream &operator<<(raw_ostream &OS,
const llvm::dxbc::ShaderVisibility &Visibility) {
- printEnum(OS, Visibility, dxbc::getShaderVisibility());
+ OS << enumToStringRef(Visibility, dxbc::getShaderVisibility());
return OS;
}
static raw_ostream &operator<<(raw_ostream &OS,
const llvm::dxbc::SamplerFilter &Filter) {
- printEnum(OS, Filter, dxbc::getSamplerFilters());
+ OS << enumToStringRef(Filter, dxbc::getSamplerFilters());
return OS;
}
static raw_ostream &operator<<(raw_ostream &OS,
const dxbc::TextureAddressMode &Address) {
- printEnum(OS, Address, dxbc::getTextureAddressModes());
+ OS << enumToStringRef(Address, dxbc::getTextureAddressModes());
return OS;
}
static raw_ostream &operator<<(raw_ostream &OS,
const dxbc::ComparisonFunc &CompFunc) {
- printEnum(OS, CompFunc, dxbc::getComparisonFuncs());
+ OS << enumToStringRef(CompFunc, dxbc::getComparisonFuncs());
return OS;
}
static raw_ostream &operator<<(raw_ostream &OS,
const dxbc::StaticBorderColor &BorderColor) {
- printEnum(OS, BorderColor, dxbc::getStaticBorderColors());
+ OS << enumToStringRef(BorderColor, dxbc::getStaticBorderColors());
return OS;
}
-static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
- {"CBV", dxil::ResourceClass::CBuffer},
- {"SRV", dxil::ResourceClass::SRV},
- {"UAV", dxil::ResourceClass::UAV},
- {"Sampler", dxil::ResourceClass::Sampler},
-};
-
-static raw_ostream &operator<<(raw_ostream &OS, const ClauseType &Type) {
- printEnum(OS, dxil::ResourceClass(llvm::to_underlying(Type)),
- ArrayRef(ResourceClassNames));
-
+static raw_ostream &operator<<(raw_ostream &OS,
+ const dxil::ResourceClass &Type) {
+ OS << dxil::getResourceClassName(Type);
return OS;
}
@@ -179,8 +153,7 @@ raw_ostream &operator<<(raw_ostream &OS, const DescriptorTableClause &Clause) {
}
raw_ostream &operator<<(raw_ostream &OS, const RootDescriptor &Descriptor) {
- ClauseType Type = ClauseType(llvm::to_underlying(Descriptor.Type));
- OS << "Root" << Type << "(" << Descriptor.Reg
+ OS << "Root" << Descriptor.Type << "(" << Descriptor.Reg
<< ", space = " << Descriptor.Space
<< ", visibility = " << Descriptor.Visibility
<< ", flags = " << Descriptor.Flags << ")";
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
index 53f5934..a5a92cb 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureMetadata.cpp
@@ -13,15 +13,22 @@
#include "llvm/Frontend/HLSL/RootSignatureMetadata.h"
#include "llvm/Frontend/HLSL/RootSignatureValidations.h"
-#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Metadata.h"
+#include "llvm/Support/DXILABI.h"
#include "llvm/Support/ScopedPrinter.h"
+using namespace llvm;
+
namespace llvm {
namespace hlsl {
namespace rootsig {
+char GenericRSMetadataError::ID;
+char InvalidRSMetadataFormat::ID;
+char InvalidRSMetadataValue::ID;
+template <typename T> char RootSignatureValidationError<T>::ID;
+
static std::optional<uint32_t> extractMdIntValue(MDNode *Node,
unsigned int OpId) {
if (auto *CI =
@@ -45,31 +52,15 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
return NodeText->getString();
}
-static bool reportError(LLVMContext *Ctx, Twine Message,
- DiagnosticSeverity Severity = DS_Error) {
- Ctx->diagnose(DiagnosticInfoGeneric(Message, Severity));
- return true;
-}
-
-static bool reportValueError(LLVMContext *Ctx, Twine ParamName,
- uint32_t Value) {
- Ctx->diagnose(DiagnosticInfoGeneric(
- "Invalid value for " + ParamName + ": " + Twine(Value), DS_Error));
- return true;
-}
-
-static const EnumEntry<dxil::ResourceClass> ResourceClassNames[] = {
- {"CBV", dxil::ResourceClass::CBuffer},
- {"SRV", dxil::ResourceClass::SRV},
- {"UAV", dxil::ResourceClass::UAV},
- {"Sampler", dxil::ResourceClass::Sampler},
-};
-
-static std::optional<StringRef> getResourceName(dxil::ResourceClass Class) {
- for (const auto &ClassEnum : ResourceClassNames)
- if (ClassEnum.Value == Class)
- return ClassEnum.Name;
- return std::nullopt;
+static Expected<dxbc::ShaderVisibility>
+extractShaderVisibility(MDNode *Node, unsigned int OpId) {
+ if (std::optional<uint32_t> Val = extractMdIntValue(Node, OpId)) {
+ if (!dxbc::isValidShaderVisibility(*Val))
+ return make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderVisibility", *Val);
+ return dxbc::ShaderVisibility(*Val);
+ }
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
}
namespace {
@@ -120,7 +111,7 @@ MDNode *MetadataBuilder::BuildRootFlags(const dxbc::RootFlags &Flags) {
IRBuilder<> Builder(Ctx);
Metadata *Operands[] = {
MDString::get(Ctx, "RootFlags"),
- ConstantAsMetadata::get(Builder.getInt32(llvm::to_underlying(Flags))),
+ ConstantAsMetadata::get(Builder.getInt32(to_underlying(Flags))),
};
return MDNode::get(Ctx, Operands);
}
@@ -130,7 +121,7 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
Metadata *Operands[] = {
MDString::get(Ctx, "RootConstants"),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Constants.Visibility))),
+ Builder.getInt32(to_underlying(Constants.Visibility))),
ConstantAsMetadata::get(Builder.getInt32(Constants.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Constants.Space)),
ConstantAsMetadata::get(Builder.getInt32(Constants.Num32BitConstants)),
@@ -140,18 +131,17 @@ MDNode *MetadataBuilder::BuildRootConstants(const RootConstants &Constants) {
MDNode *MetadataBuilder::BuildRootDescriptor(const RootDescriptor &Descriptor) {
IRBuilder<> Builder(Ctx);
- std::optional<StringRef> ResName = getResourceName(
- dxil::ResourceClass(llvm::to_underlying(Descriptor.Type)));
- assert(ResName && "Provided an invalid Resource Class");
- llvm::SmallString<7> Name({"Root", *ResName});
+ StringRef ResName = dxil::getResourceClassName(Descriptor.Type);
+ assert(!ResName.empty() && "Provided an invalid Resource Class");
+ SmallString<7> Name({"Root", ResName});
Metadata *Operands[] = {
MDString::get(Ctx, Name),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Descriptor.Visibility))),
+ Builder.getInt32(to_underlying(Descriptor.Visibility))),
ConstantAsMetadata::get(Builder.getInt32(Descriptor.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Descriptor.Space)),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Descriptor.Flags))),
+ Builder.getInt32(to_underlying(Descriptor.Flags))),
};
return MDNode::get(Ctx, Operands);
}
@@ -162,7 +152,7 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
// Set the mandatory arguments
TableOperands.push_back(MDString::get(Ctx, "DescriptorTable"));
TableOperands.push_back(ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Table.Visibility))));
+ Builder.getInt32(to_underlying(Table.Visibility))));
// Remaining operands are references to the table's clauses. The in-memory
// representation of the Root Elements created from parsing will ensure that
@@ -181,17 +171,15 @@ MDNode *MetadataBuilder::BuildDescriptorTable(const DescriptorTable &Table) {
MDNode *MetadataBuilder::BuildDescriptorTableClause(
const DescriptorTableClause &Clause) {
IRBuilder<> Builder(Ctx);
- std::optional<StringRef> ResName =
- getResourceName(dxil::ResourceClass(llvm::to_underlying(Clause.Type)));
- assert(ResName && "Provided an invalid Resource Class");
+ StringRef ResName = dxil::getResourceClassName(Clause.Type);
+ assert(!ResName.empty() && "Provided an invalid Resource Class");
Metadata *Operands[] = {
- MDString::get(Ctx, *ResName),
+ MDString::get(Ctx, ResName),
ConstantAsMetadata::get(Builder.getInt32(Clause.NumDescriptors)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Space)),
ConstantAsMetadata::get(Builder.getInt32(Clause.Offset)),
- ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Clause.Flags))),
+ ConstantAsMetadata::get(Builder.getInt32(to_underlying(Clause.Flags))),
};
return MDNode::get(Ctx, Operands);
}
@@ -200,151 +188,139 @@ MDNode *MetadataBuilder::BuildStaticSampler(const StaticSampler &Sampler) {
IRBuilder<> Builder(Ctx);
Metadata *Operands[] = {
MDString::get(Ctx, "StaticSampler"),
+ ConstantAsMetadata::get(Builder.getInt32(to_underlying(Sampler.Filter))),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Sampler.Filter))),
+ Builder.getInt32(to_underlying(Sampler.AddressU))),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Sampler.AddressU))),
+ Builder.getInt32(to_underlying(Sampler.AddressV))),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Sampler.AddressV))),
+ Builder.getInt32(to_underlying(Sampler.AddressW))),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Sampler.AddressW))),
- ConstantAsMetadata::get(llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx),
- Sampler.MipLODBias)),
+ ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MipLODBias)),
ConstantAsMetadata::get(Builder.getInt32(Sampler.MaxAnisotropy)),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Sampler.CompFunc))),
+ Builder.getInt32(to_underlying(Sampler.CompFunc))),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Sampler.BorderColor))),
+ Builder.getInt32(to_underlying(Sampler.BorderColor))),
ConstantAsMetadata::get(
- llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MinLOD)),
+ ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MinLOD)),
ConstantAsMetadata::get(
- llvm::ConstantFP::get(llvm::Type::getFloatTy(Ctx), Sampler.MaxLOD)),
+ ConstantFP::get(Type::getFloatTy(Ctx), Sampler.MaxLOD)),
ConstantAsMetadata::get(Builder.getInt32(Sampler.Reg.Number)),
ConstantAsMetadata::get(Builder.getInt32(Sampler.Space)),
ConstantAsMetadata::get(
- Builder.getInt32(llvm::to_underlying(Sampler.Visibility))),
+ Builder.getInt32(to_underlying(Sampler.Visibility))),
};
return MDNode::get(Ctx, Operands);
}
-bool MetadataParser::parseRootFlags(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootFlagNode) {
-
+Error MetadataParser::parseRootFlags(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootFlagNode) {
if (RootFlagNode->getNumOperands() != 2)
- return reportError(Ctx, "Invalid format for RootFlag Element");
+ return make_error<InvalidRSMetadataFormat>("RootFlag Element");
if (std::optional<uint32_t> Val = extractMdIntValue(RootFlagNode, 1))
RSD.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for RootFlag");
+ return make_error<InvalidRSMetadataValue>("RootFlag");
- return false;
+ return Error::success();
}
-bool MetadataParser::parseRootConstants(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootConstantNode) {
-
+Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *RootConstantNode) {
if (RootConstantNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for RootConstants Element");
+ return make_error<InvalidRSMetadataFormat>("RootConstants Element");
- dxbc::RTS0::v1::RootParameterHeader Header;
- // The parameter offset doesn't matter here - we recalculate it during
- // serialization Header.ParameterOffset = 0;
- Header.ParameterType =
- llvm::to_underlying(dxbc::RootParameterType::Constants32Bit);
-
- if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ Expected<dxbc::ShaderVisibility> Visibility =
+ extractShaderVisibility(RootConstantNode, 1);
+ if (auto E = Visibility.takeError())
+ return Error(std::move(E));
- dxbc::RTS0::v1::RootConstants Constants;
+ mcdxbc::RootConstants Constants;
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 2))
Constants.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 3))
Constants.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RootConstantNode, 4))
Constants.Num32BitValues = *Val;
else
- return reportError(Ctx, "Invalid value for Num32BitValues");
+ return make_error<InvalidRSMetadataValue>("Num32BitValues");
- RSD.ParametersContainer.addParameter(Header, Constants);
+ RSD.ParametersContainer.addParameter(dxbc::RootParameterType::Constants32Bit,
+ *Visibility, Constants);
- return false;
+ return Error::success();
}
-bool MetadataParser::parseRootDescriptors(
- LLVMContext *Ctx, mcdxbc::RootSignatureDesc &RSD,
- MDNode *RootDescriptorNode, RootSignatureElementKind ElementKind) {
- assert(ElementKind == RootSignatureElementKind::SRV ||
- ElementKind == RootSignatureElementKind::UAV ||
- ElementKind == RootSignatureElementKind::CBV &&
- "parseRootDescriptors should only be called with RootDescriptor "
- "element kind.");
+Error MetadataParser::parseRootDescriptors(
+ mcdxbc::RootSignatureDesc &RSD, MDNode *RootDescriptorNode,
+ RootSignatureElementKind ElementKind) {
+ assert((ElementKind == RootSignatureElementKind::SRV ||
+ ElementKind == RootSignatureElementKind::UAV ||
+ ElementKind == RootSignatureElementKind::CBV) &&
+ "parseRootDescriptors should only be called with RootDescriptor "
+ "element kind.");
if (RootDescriptorNode->getNumOperands() != 5)
- return reportError(Ctx, "Invalid format for Root Descriptor Element");
+ return make_error<InvalidRSMetadataFormat>("Root Descriptor Element");
- dxbc::RTS0::v1::RootParameterHeader Header;
+ dxbc::RootParameterType Type;
switch (ElementKind) {
case RootSignatureElementKind::SRV:
- Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::SRV);
+ Type = dxbc::RootParameterType::SRV;
break;
case RootSignatureElementKind::UAV:
- Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::UAV);
+ Type = dxbc::RootParameterType::UAV;
break;
case RootSignatureElementKind::CBV:
- Header.ParameterType = llvm::to_underlying(dxbc::RootParameterType::CBV);
+ Type = dxbc::RootParameterType::CBV;
break;
default:
llvm_unreachable("invalid Root Descriptor kind");
break;
}
- if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ Expected<dxbc::ShaderVisibility> Visibility =
+ extractShaderVisibility(RootDescriptorNode, 1);
+ if (auto E = Visibility.takeError())
+ return Error(std::move(E));
- dxbc::RTS0::v2::RootDescriptor Descriptor;
+ mcdxbc::RootDescriptor Descriptor;
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 2))
Descriptor.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 3))
Descriptor.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (RSD.Version == 1) {
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
+ return Error::success();
}
assert(RSD.Version > 1);
if (std::optional<uint32_t> Val = extractMdIntValue(RootDescriptorNode, 4))
Descriptor.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for Root Descriptor Flags");
+ return make_error<InvalidRSMetadataValue>("Root Descriptor Flags");
- RSD.ParametersContainer.addParameter(Header, Descriptor);
- return false;
+ RSD.ParametersContainer.addParameter(Type, *Visibility, Descriptor);
+ return Error::success();
}
-bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
- mcdxbc::DescriptorTable &Table,
- MDNode *RangeDescriptorNode) {
-
+Error MetadataParser::parseDescriptorRange(mcdxbc::DescriptorTable &Table,
+ MDNode *RangeDescriptorNode) {
if (RangeDescriptorNode->getNumOperands() != 6)
- return reportError(Ctx, "Invalid format for Descriptor Range");
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
dxbc::RTS0::v2::DescriptorRange Range;
@@ -352,162 +328,159 @@ bool MetadataParser::parseDescriptorRange(LLVMContext *Ctx,
extractMdStringValue(RangeDescriptorNode, 0);
if (!ElementText.has_value())
- return reportError(Ctx, "Descriptor Range, first element is not a string.");
+ return make_error<InvalidRSMetadataFormat>("Descriptor Range");
Range.RangeType =
StringSwitch<uint32_t>(*ElementText)
- .Case("CBV", llvm::to_underlying(dxbc::DescriptorRangeType::CBV))
- .Case("SRV", llvm::to_underlying(dxbc::DescriptorRangeType::SRV))
- .Case("UAV", llvm::to_underlying(dxbc::DescriptorRangeType::UAV))
- .Case("Sampler",
- llvm::to_underlying(dxbc::DescriptorRangeType::Sampler))
+ .Case("CBV", to_underlying(dxbc::DescriptorRangeType::CBV))
+ .Case("SRV", to_underlying(dxbc::DescriptorRangeType::SRV))
+ .Case("UAV", to_underlying(dxbc::DescriptorRangeType::UAV))
+ .Case("Sampler", to_underlying(dxbc::DescriptorRangeType::Sampler))
.Default(~0U);
if (Range.RangeType == ~0U)
- return reportError(Ctx, "Invalid Descriptor Range type: " + *ElementText);
+ return make_error<GenericRSMetadataError>("Invalid Descriptor Range type.",
+ RangeDescriptorNode);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 1))
Range.NumDescriptors = *Val;
else
- return reportError(Ctx, "Invalid value for Number of Descriptor in Range");
+ return make_error<GenericRSMetadataError>("Number of Descriptor in Range",
+ RangeDescriptorNode);
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 2))
Range.BaseShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for BaseShaderRegister");
+ return make_error<InvalidRSMetadataValue>("BaseShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 3))
Range.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 4))
Range.OffsetInDescriptorsFromTableStart = *Val;
else
- return reportError(Ctx,
- "Invalid value for OffsetInDescriptorsFromTableStart");
+ return make_error<InvalidRSMetadataValue>(
+ "OffsetInDescriptorsFromTableStart");
if (std::optional<uint32_t> Val = extractMdIntValue(RangeDescriptorNode, 5))
Range.Flags = *Val;
else
- return reportError(Ctx, "Invalid value for Descriptor Range Flags");
+ return make_error<InvalidRSMetadataValue>("Descriptor Range Flags");
Table.Ranges.push_back(Range);
- return false;
+ return Error::success();
}
-bool MetadataParser::parseDescriptorTable(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *DescriptorTableNode) {
+Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *DescriptorTableNode) {
const unsigned int NumOperands = DescriptorTableNode->getNumOperands();
if (NumOperands < 2)
- return reportError(Ctx, "Invalid format for Descriptor Table");
+ return make_error<InvalidRSMetadataFormat>("Descriptor Table");
- dxbc::RTS0::v1::RootParameterHeader Header;
- if (std::optional<uint32_t> Val = extractMdIntValue(DescriptorTableNode, 1))
- Header.ShaderVisibility = *Val;
- else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ Expected<dxbc::ShaderVisibility> Visibility =
+ extractShaderVisibility(DescriptorTableNode, 1);
+ if (auto E = Visibility.takeError())
+ return Error(std::move(E));
mcdxbc::DescriptorTable Table;
- Header.ParameterType =
- llvm::to_underlying(dxbc::RootParameterType::DescriptorTable);
for (unsigned int I = 2; I < NumOperands; I++) {
MDNode *Element = dyn_cast<MDNode>(DescriptorTableNode->getOperand(I));
if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
+ return make_error<GenericRSMetadataError>(
+ "Missing Root Element Metadata Node.", DescriptorTableNode);
- if (parseDescriptorRange(Ctx, Table, Element))
- return true;
+ if (auto Err = parseDescriptorRange(Table, Element))
+ return Err;
}
- RSD.ParametersContainer.addParameter(Header, Table);
- return false;
+ RSD.ParametersContainer.addParameter(dxbc::RootParameterType::DescriptorTable,
+ *Visibility, Table);
+ return Error::success();
}
-bool MetadataParser::parseStaticSampler(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *StaticSamplerNode) {
+Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *StaticSamplerNode) {
if (StaticSamplerNode->getNumOperands() != 14)
- return reportError(Ctx, "Invalid format for Static Sampler");
+ return make_error<InvalidRSMetadataFormat>("Static Sampler");
dxbc::RTS0::v1::StaticSampler Sampler;
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 1))
Sampler.Filter = *Val;
else
- return reportError(Ctx, "Invalid value for Filter");
+ return make_error<InvalidRSMetadataValue>("Filter");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 2))
Sampler.AddressU = *Val;
else
- return reportError(Ctx, "Invalid value for AddressU");
+ return make_error<InvalidRSMetadataValue>("AddressU");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 3))
Sampler.AddressV = *Val;
else
- return reportError(Ctx, "Invalid value for AddressV");
+ return make_error<InvalidRSMetadataValue>("AddressV");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 4))
Sampler.AddressW = *Val;
else
- return reportError(Ctx, "Invalid value for AddressW");
+ return make_error<InvalidRSMetadataValue>("AddressW");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 5))
Sampler.MipLODBias = *Val;
else
- return reportError(Ctx, "Invalid value for MipLODBias");
+ return make_error<InvalidRSMetadataValue>("MipLODBias");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 6))
Sampler.MaxAnisotropy = *Val;
else
- return reportError(Ctx, "Invalid value for MaxAnisotropy");
+ return make_error<InvalidRSMetadataValue>("MaxAnisotropy");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 7))
Sampler.ComparisonFunc = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return make_error<InvalidRSMetadataValue>("ComparisonFunc");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 8))
Sampler.BorderColor = *Val;
else
- return reportError(Ctx, "Invalid value for ComparisonFunc ");
+ return make_error<InvalidRSMetadataValue>("ComparisonFunc");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 9))
Sampler.MinLOD = *Val;
else
- return reportError(Ctx, "Invalid value for MinLOD");
+ return make_error<InvalidRSMetadataValue>("MinLOD");
if (std::optional<float> Val = extractMdFloatValue(StaticSamplerNode, 10))
Sampler.MaxLOD = *Val;
else
- return reportError(Ctx, "Invalid value for MaxLOD");
+ return make_error<InvalidRSMetadataValue>("MaxLOD");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 11))
Sampler.ShaderRegister = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderRegister");
+ return make_error<InvalidRSMetadataValue>("ShaderRegister");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 12))
Sampler.RegisterSpace = *Val;
else
- return reportError(Ctx, "Invalid value for RegisterSpace");
+ return make_error<InvalidRSMetadataValue>("RegisterSpace");
if (std::optional<uint32_t> Val = extractMdIntValue(StaticSamplerNode, 13))
Sampler.ShaderVisibility = *Val;
else
- return reportError(Ctx, "Invalid value for ShaderVisibility");
+ return make_error<InvalidRSMetadataValue>("ShaderVisibility");
RSD.StaticSamplers.push_back(Sampler);
- return false;
+ return Error::success();
}
-bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD,
- MDNode *Element) {
+Error MetadataParser::parseRootSignatureElement(mcdxbc::RootSignatureDesc &RSD,
+ MDNode *Element) {
std::optional<StringRef> ElementText = extractMdStringValue(Element, 0);
if (!ElementText.has_value())
- return reportError(Ctx, "Invalid format for Root Element");
+ return make_error<InvalidRSMetadataFormat>("Root Element");
RootSignatureElementKind ElementKind =
StringSwitch<RootSignatureElementKind>(*ElementText)
@@ -523,79 +496,103 @@ bool MetadataParser::parseRootSignatureElement(LLVMContext *Ctx,
switch (ElementKind) {
case RootSignatureElementKind::RootFlags:
- return parseRootFlags(Ctx, RSD, Element);
+ return parseRootFlags(RSD, Element);
case RootSignatureElementKind::RootConstants:
- return parseRootConstants(Ctx, RSD, Element);
+ return parseRootConstants(RSD, Element);
case RootSignatureElementKind::CBV:
case RootSignatureElementKind::SRV:
case RootSignatureElementKind::UAV:
- return parseRootDescriptors(Ctx, RSD, Element, ElementKind);
+ return parseRootDescriptors(RSD, Element, ElementKind);
case RootSignatureElementKind::DescriptorTable:
- return parseDescriptorTable(Ctx, RSD, Element);
+ return parseDescriptorTable(RSD, Element);
case RootSignatureElementKind::StaticSamplers:
- return parseStaticSampler(Ctx, RSD, Element);
+ return parseStaticSampler(RSD, Element);
case RootSignatureElementKind::Error:
- return reportError(Ctx, "Invalid Root Signature Element: " + *ElementText);
+ return make_error<GenericRSMetadataError>("Invalid Root Signature Element",
+ Element);
}
llvm_unreachable("Unhandled RootSignatureElementKind enum.");
}
-bool MetadataParser::validateRootSignature(
- LLVMContext *Ctx, const llvm::mcdxbc::RootSignatureDesc &RSD) {
- if (!llvm::hlsl::rootsig::verifyVersion(RSD.Version)) {
- return reportValueError(Ctx, "Version", RSD.Version);
+Error MetadataParser::validateRootSignature(
+ const mcdxbc::RootSignatureDesc &RSD) {
+ Error DeferredErrs = Error::success();
+ if (!hlsl::rootsig::verifyVersion(RSD.Version)) {
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "Version", RSD.Version));
}
- if (!llvm::hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
- return reportValueError(Ctx, "RootFlags", RSD.Flags);
+ if (!hlsl::rootsig::verifyRootFlag(RSD.Flags)) {
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "RootFlags", RSD.Flags));
}
for (const mcdxbc::RootParameterInfo &Info : RSD.ParametersContainer) {
- if (!dxbc::isValidShaderVisibility(Info.Header.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Info.Header.ShaderVisibility);
-
- assert(dxbc::isValidParameterType(Info.Header.ParameterType) &&
- "Invalid value for ParameterType");
- switch (Info.Header.ParameterType) {
+ switch (Info.Type) {
+ case dxbc::RootParameterType::Constants32Bit:
+ break;
- case llvm::to_underlying(dxbc::RootParameterType::CBV):
- case llvm::to_underlying(dxbc::RootParameterType::UAV):
- case llvm::to_underlying(dxbc::RootParameterType::SRV): {
- const dxbc::RTS0::v2::RootDescriptor &Descriptor =
+ case dxbc::RootParameterType::CBV:
+ case dxbc::RootParameterType::UAV:
+ case dxbc::RootParameterType::SRV: {
+ const mcdxbc::RootDescriptor &Descriptor =
RSD.ParametersContainer.getRootDescriptor(Info.Location);
- if (!llvm::hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister",
- Descriptor.ShaderRegister);
-
- if (!llvm::hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Descriptor.RegisterSpace);
+ if (!hlsl::rootsig::verifyRegisterValue(Descriptor.ShaderRegister))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderRegister", Descriptor.ShaderRegister));
+
+ if (!hlsl::rootsig::verifyRegisterSpace(Descriptor.RegisterSpace))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "RegisterSpace", Descriptor.RegisterSpace));
if (RSD.Version > 1) {
- if (!llvm::hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
- Descriptor.Flags))
- return reportValueError(Ctx, "RootDescriptorFlag", Descriptor.Flags);
+ if (!hlsl::rootsig::verifyRootDescriptorFlag(RSD.Version,
+ Descriptor.Flags))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "RootDescriptorFlag", Descriptor.Flags));
}
break;
}
- case llvm::to_underlying(dxbc::RootParameterType::DescriptorTable): {
+ case dxbc::RootParameterType::DescriptorTable: {
const mcdxbc::DescriptorTable &Table =
RSD.ParametersContainer.getDescriptorTable(Info.Location);
for (const dxbc::RTS0::v2::DescriptorRange &Range : Table) {
- if (!llvm::hlsl::rootsig::verifyRangeType(Range.RangeType))
- return reportValueError(Ctx, "RangeType", Range.RangeType);
-
- if (!llvm::hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Range.RegisterSpace);
-
- if (!llvm::hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
- return reportValueError(Ctx, "NumDescriptors", Range.NumDescriptors);
-
- if (!llvm::hlsl::rootsig::verifyDescriptorRangeFlag(
+ if (!hlsl::rootsig::verifyRangeType(Range.RangeType))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "RangeType", Range.RangeType));
+
+ if (!hlsl::rootsig::verifyRegisterSpace(Range.RegisterSpace))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "RegisterSpace", Range.RegisterSpace));
+
+ if (!hlsl::rootsig::verifyNumDescriptors(Range.NumDescriptors))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "NumDescriptors", Range.NumDescriptors));
+
+ if (!hlsl::rootsig::verifyDescriptorRangeFlag(
RSD.Version, Range.RangeType, Range.Flags))
- return reportValueError(Ctx, "DescriptorFlag", Range.Flags);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "DescriptorFlag", Range.Flags));
}
break;
}
@@ -603,65 +600,108 @@ bool MetadataParser::validateRootSignature(
}
for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers) {
- if (!llvm::hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
- return reportValueError(Ctx, "Filter", Sampler.Filter);
-
- if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressU))
- return reportValueError(Ctx, "AddressU", Sampler.AddressU);
-
- if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressV))
- return reportValueError(Ctx, "AddressV", Sampler.AddressV);
-
- if (!llvm::hlsl::rootsig::verifyAddress(Sampler.AddressW))
- return reportValueError(Ctx, "AddressW", Sampler.AddressW);
-
- if (!llvm::hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
- return reportValueError(Ctx, "MipLODBias", Sampler.MipLODBias);
-
- if (!llvm::hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
- return reportValueError(Ctx, "MaxAnisotropy", Sampler.MaxAnisotropy);
-
- if (!llvm::hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
- return reportValueError(Ctx, "ComparisonFunc", Sampler.ComparisonFunc);
-
- if (!llvm::hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
- return reportValueError(Ctx, "BorderColor", Sampler.BorderColor);
-
- if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MinLOD))
- return reportValueError(Ctx, "MinLOD", Sampler.MinLOD);
-
- if (!llvm::hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
- return reportValueError(Ctx, "MaxLOD", Sampler.MaxLOD);
-
- if (!llvm::hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
- return reportValueError(Ctx, "ShaderRegister", Sampler.ShaderRegister);
-
- if (!llvm::hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
- return reportValueError(Ctx, "RegisterSpace", Sampler.RegisterSpace);
+ if (!hlsl::rootsig::verifySamplerFilter(Sampler.Filter))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "Filter", Sampler.Filter));
+
+ if (!hlsl::rootsig::verifyAddress(Sampler.AddressU))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "AddressU", Sampler.AddressU));
+
+ if (!hlsl::rootsig::verifyAddress(Sampler.AddressV))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "AddressV", Sampler.AddressV));
+
+ if (!hlsl::rootsig::verifyAddress(Sampler.AddressW))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "AddressW", Sampler.AddressW));
+
+ if (!hlsl::rootsig::verifyMipLODBias(Sampler.MipLODBias))
+ DeferredErrs = joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<float>>(
+ "MipLODBias", Sampler.MipLODBias));
+
+ if (!hlsl::rootsig::verifyMaxAnisotropy(Sampler.MaxAnisotropy))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "MaxAnisotropy", Sampler.MaxAnisotropy));
+
+ if (!hlsl::rootsig::verifyComparisonFunc(Sampler.ComparisonFunc))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "ComparisonFunc", Sampler.ComparisonFunc));
+
+ if (!hlsl::rootsig::verifyBorderColor(Sampler.BorderColor))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "BorderColor", Sampler.BorderColor));
+
+ if (!hlsl::rootsig::verifyLOD(Sampler.MinLOD))
+ DeferredErrs = joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<float>>(
+ "MinLOD", Sampler.MinLOD));
+
+ if (!hlsl::rootsig::verifyLOD(Sampler.MaxLOD))
+ DeferredErrs = joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<float>>(
+ "MaxLOD", Sampler.MaxLOD));
+
+ if (!hlsl::rootsig::verifyRegisterValue(Sampler.ShaderRegister))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderRegister", Sampler.ShaderRegister));
+
+ if (!hlsl::rootsig::verifyRegisterSpace(Sampler.RegisterSpace))
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "RegisterSpace", Sampler.RegisterSpace));
if (!dxbc::isValidShaderVisibility(Sampler.ShaderVisibility))
- return reportValueError(Ctx, "ShaderVisibility",
- Sampler.ShaderVisibility);
+ DeferredErrs =
+ joinErrors(std::move(DeferredErrs),
+ make_error<RootSignatureValidationError<uint32_t>>(
+ "ShaderVisibility", Sampler.ShaderVisibility));
}
- return false;
+ return DeferredErrs;
}
-bool MetadataParser::ParseRootSignature(LLVMContext *Ctx,
- mcdxbc::RootSignatureDesc &RSD) {
- bool HasError = false;
-
- // Loop through the Root Elements of the root signature.
+Expected<mcdxbc::RootSignatureDesc>
+MetadataParser::ParseRootSignature(uint32_t Version) {
+ Error DeferredErrs = Error::success();
+ mcdxbc::RootSignatureDesc RSD;
+ RSD.Version = Version;
for (const auto &Operand : Root->operands()) {
MDNode *Element = dyn_cast<MDNode>(Operand);
if (Element == nullptr)
- return reportError(Ctx, "Missing Root Element Metadata Node.");
+ return joinErrors(std::move(DeferredErrs),
+ make_error<GenericRSMetadataError>(
+ "Missing Root Element Metadata Node.", nullptr));
- HasError = HasError || parseRootSignatureElement(Ctx, RSD, Element) ||
- validateRootSignature(Ctx, RSD);
+ if (auto Err = parseRootSignatureElement(RSD, Element))
+ DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
}
- return HasError;
+ if (auto Err = validateRootSignature(RSD))
+ DeferredErrs = joinErrors(std::move(DeferredErrs), std::move(Err));
+
+ if (DeferredErrs)
+ return std::move(DeferredErrs);
+
+ return std::move(RSD);
}
} // namespace rootsig
} // namespace hlsl
diff --git a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
index f11c7d2..72308a3d 100644
--- a/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
+++ b/llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
@@ -29,7 +29,7 @@ bool verifyRegisterValue(uint32_t RegisterValue) {
// This Range is reserverved, therefore invalid, according to the spec
// https://github.com/llvm/wg-hlsl/blob/main/proposals/0002-root-signature-in-clang.md#all-the-values-should-be-legal
bool verifyRegisterSpace(uint32_t RegisterSpace) {
- return !(RegisterSpace >= 0xFFFFFFF0 && RegisterSpace <= 0xFFFFFFFF);
+ return !(RegisterSpace >= 0xFFFFFFF0);
}
bool verifyRootDescriptorFlag(uint32_t Version, uint32_t FlagsVal) {
@@ -180,140 +180,6 @@ bool verifyBorderColor(uint32_t BorderColor) {
bool verifyLOD(float LOD) { return !std::isnan(LOD); }
-std::optional<const RangeInfo *>
-ResourceRange::getOverlapping(const RangeInfo &Info) const {
- MapT::const_iterator Interval = Intervals.find(Info.LowerBound);
- if (!Interval.valid() || Info.UpperBound < Interval.start())
- return std::nullopt;
- return Interval.value();
-}
-
-const RangeInfo *ResourceRange::lookup(uint32_t X) const {
- return Intervals.lookup(X, nullptr);
-}
-
-void ResourceRange::clear() { return Intervals.clear(); }
-
-std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
- uint32_t LowerBound = Info.LowerBound;
- uint32_t UpperBound = Info.UpperBound;
-
- std::optional<const RangeInfo *> Res = std::nullopt;
- MapT::iterator Interval = Intervals.begin();
-
- while (true) {
- if (UpperBound < LowerBound)
- break;
-
- Interval.advanceTo(LowerBound);
- if (!Interval.valid()) // No interval found
- break;
-
- // Let Interval = [x;y] and [LowerBound;UpperBound] = [a;b] and note that
- // a <= y implicitly from Intervals.find(LowerBound)
- if (UpperBound < Interval.start())
- break; // found interval does not overlap with inserted one
-
- if (!Res.has_value()) // Update to be the first found intersection
- Res = Interval.value();
-
- if (Interval.start() <= LowerBound && UpperBound <= Interval.stop()) {
- // x <= a <= b <= y implies that [a;b] is covered by [x;y]
- // -> so we don't need to insert this, report an overlap
- return Res;
- } else if (LowerBound <= Interval.start() &&
- Interval.stop() <= UpperBound) {
- // a <= x <= y <= b implies that [x;y] is covered by [a;b]
- // -> so remove the existing interval that we will cover with the
- // overwrite
- Interval.erase();
- } else if (LowerBound < Interval.start() && UpperBound <= Interval.stop()) {
- // a < x <= b <= y implies that [a; x] is not covered but [x;b] is
- // -> so set b = x - 1 such that [a;x-1] is now the interval to insert
- UpperBound = Interval.start() - 1;
- } else if (Interval.start() <= LowerBound && Interval.stop() < UpperBound) {
- // a < x <= b <= y implies that [y; b] is not covered but [a;y] is
- // -> so set a = y + 1 such that [y+1;b] is now the interval to insert
- LowerBound = Interval.stop() + 1;
- }
- }
-
- assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
- Intervals.insert(LowerBound, UpperBound, &Info);
- return Res;
-}
-
-llvm::SmallVector<OverlappingRanges>
-findOverlappingRanges(ArrayRef<RangeInfo> Infos) {
- // It is expected that Infos is filled with valid RangeInfos and that
- // they are sorted with respect to the RangeInfo <operator
- assert(llvm::is_sorted(Infos) && "Ranges must be sorted");
-
- llvm::SmallVector<OverlappingRanges> Overlaps;
- using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>;
-
- // First we will init our state to track:
- if (Infos.size() == 0)
- return Overlaps; // No ranges to overlap
- GroupT CurGroup = {Infos[0].Class, Infos[0].Space};
-
- // Create a ResourceRange for each Visibility
- ResourceRange::MapT::Allocator Allocator;
- std::array<ResourceRange, 8> Ranges = {
- ResourceRange(Allocator), // All
- ResourceRange(Allocator), // Vertex
- ResourceRange(Allocator), // Hull
- ResourceRange(Allocator), // Domain
- ResourceRange(Allocator), // Geometry
- ResourceRange(Allocator), // Pixel
- ResourceRange(Allocator), // Amplification
- ResourceRange(Allocator), // Mesh
- };
-
- // Reset the ResourceRanges for when we iterate through a new group
- auto ClearRanges = [&Ranges]() {
- for (ResourceRange &Range : Ranges)
- Range.clear();
- };
-
- // Iterate through collected RangeInfos
- for (const RangeInfo &Info : Infos) {
- GroupT InfoGroup = {Info.Class, Info.Space};
- // Reset our ResourceRanges when we enter a new group
- if (CurGroup != InfoGroup) {
- ClearRanges();
- CurGroup = InfoGroup;
- }
-
- // Insert range info into corresponding Visibility ResourceRange
- ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)];
- if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info))
- Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
-
- // Check for overlap in all overlapping Visibility ResourceRanges
- //
- // If the range that we are inserting has ShaderVisiblity::All it needs to
- // check for an overlap in all other visibility types as well.
- // Otherwise, the range that is inserted needs to check that it does not
- // overlap with ShaderVisibility::All.
- //
- // OverlapRanges will be an ArrayRef to all non-all visibility
- // ResourceRanges in the former case and it will be an ArrayRef to just the
- // all visiblity ResourceRange in the latter case.
- ArrayRef<ResourceRange> OverlapRanges =
- Info.Visibility == llvm::dxbc::ShaderVisibility::All
- ? ArrayRef<ResourceRange>{Ranges}.drop_front()
- : ArrayRef<ResourceRange>{Ranges}.take_front();
-
- for (const ResourceRange &Range : OverlapRanges)
- if (std::optional<const RangeInfo *> Overlapping =
- Range.getOverlapping(Info))
- Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
- }
-
- return Overlaps;
-}
-
} // namespace rootsig
} // namespace hlsl
} // namespace llvm